1 use crate::p2::{OutputStream, Pollable, StreamError};
2 use bytes::Bytes;
3 use std::pin::pin;
4 use std::sync::{Arc, Mutex};
5 use std::task::{Context, Poll, Waker};
6 use wasmtime::format_err;
7 
8 #[derive(Debug)]
9 struct WorkerState {
10     alive: bool,
11     items: std::collections::VecDeque<Bytes>,
12     write_budget: usize,
13     flush_pending: bool,
14     error: Option<wasmtime::Error>,
15     write_ready_changed: Option<Waker>,
16 }
17 
18 impl WorkerState {
check_error(&mut self) -> Result<(), StreamError>19     fn check_error(&mut self) -> Result<(), StreamError> {
20         if let Some(e) = self.error.take() {
21             return Err(StreamError::LastOperationFailed(e));
22         }
23         if !self.alive {
24             return Err(StreamError::Closed);
25         }
26         Ok(())
27     }
28 }
29 
30 struct Worker {
31     state: Mutex<WorkerState>,
32     new_work: tokio::sync::Notify,
33 }
34 
35 enum Job {
36     Flush,
37     Write(Bytes),
38 }
39 
40 impl Worker {
new(write_budget: usize) -> Self41     fn new(write_budget: usize) -> Self {
42         Self {
43             state: Mutex::new(WorkerState {
44                 alive: true,
45                 items: std::collections::VecDeque::new(),
46                 write_budget,
47                 flush_pending: false,
48                 error: None,
49                 write_ready_changed: None,
50             }),
51             new_work: tokio::sync::Notify::new(),
52         }
53     }
check_write(&self) -> Result<usize, StreamError>54     fn check_write(&self) -> Result<usize, StreamError> {
55         let mut state = self.state();
56         if let Err(e) = state.check_error() {
57             return Err(e);
58         }
59 
60         if state.flush_pending || state.write_budget == 0 {
61             return Ok(0);
62         }
63 
64         Ok(state.write_budget)
65     }
state(&self) -> std::sync::MutexGuard<'_, WorkerState>66     fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> {
67         self.state.lock().unwrap()
68     }
pop(&self) -> Option<Job>69     fn pop(&self) -> Option<Job> {
70         let mut state = self.state();
71         if state.items.is_empty() {
72             if state.flush_pending {
73                 return Some(Job::Flush);
74             }
75         } else if let Some(bytes) = state.items.pop_front() {
76             return Some(Job::Write(bytes));
77         }
78 
79         None
80     }
report_error(&self, e: std::io::Error)81     fn report_error(&self, e: std::io::Error) {
82         let waker = {
83             let mut state = self.state();
84             state.alive = false;
85             state.error = Some(e.into());
86             state.flush_pending = false;
87             state.write_ready_changed.take()
88         };
89         if let Some(waker) = waker {
90             waker.wake();
91         }
92     }
work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T)93     async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) {
94         use tokio::io::AsyncWriteExt;
95         let mut writer = pin!(writer);
96         loop {
97             while let Some(job) = self.pop() {
98                 match job {
99                     Job::Flush => {
100                         if let Err(e) = writer.flush().await {
101                             self.report_error(e);
102                             return;
103                         }
104 
105                         tracing::debug!("worker marking flush complete");
106                         self.state().flush_pending = false;
107                     }
108 
109                     Job::Write(mut bytes) => {
110                         tracing::debug!("worker writing: {bytes:?}");
111                         let len = bytes.len();
112                         match writer.write_all_buf(&mut bytes).await {
113                             Err(e) => {
114                                 self.report_error(e);
115                                 return;
116                             }
117                             Ok(_) => {
118                                 self.state().write_budget += len;
119                             }
120                         }
121                     }
122                 }
123 
124                 let waker = self.state().write_ready_changed.take();
125                 if let Some(waker) = waker {
126                     waker.wake();
127                 }
128             }
129             self.new_work.notified().await;
130         }
131     }
132 }
133 
134 /// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl
135 pub struct AsyncWriteStream {
136     worker: Arc<Worker>,
137     join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
138 }
139 
140 impl AsyncWriteStream {
141     /// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl
142     /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`].
new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self143     pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self {
144         let worker = Arc::new(Worker::new(write_budget));
145 
146         let w = Arc::clone(&worker);
147         let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
148 
149         AsyncWriteStream {
150             worker,
151             join_handle: Some(join_handle),
152         }
153     }
154 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>155     pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
156         let mut state = self.worker.state();
157         if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0)
158         {
159             return Poll::Ready(());
160         }
161         state.write_ready_changed = Some(cx.waker().clone());
162         Poll::Pending
163     }
164 }
165 
166 #[async_trait::async_trait]
167 impl OutputStream for AsyncWriteStream {
write(&mut self, bytes: Bytes) -> Result<(), StreamError>168     fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
169         let mut state = self.worker.state();
170         state.check_error()?;
171         if state.flush_pending {
172             return Err(StreamError::Trap(format_err!(
173                 "write not permitted while flush pending"
174             )));
175         }
176         match state.write_budget.checked_sub(bytes.len()) {
177             Some(remaining_budget) => {
178                 state.write_budget = remaining_budget;
179                 state.items.push_back(bytes);
180             }
181             None => return Err(StreamError::Trap(format_err!("write exceeded budget"))),
182         }
183         drop(state);
184         self.worker.new_work.notify_one();
185         Ok(())
186     }
flush(&mut self) -> Result<(), StreamError>187     fn flush(&mut self) -> Result<(), StreamError> {
188         let mut state = self.worker.state();
189         state.check_error()?;
190 
191         state.flush_pending = true;
192         self.worker.new_work.notify_one();
193 
194         Ok(())
195     }
196 
check_write(&mut self) -> Result<usize, StreamError>197     fn check_write(&mut self) -> Result<usize, StreamError> {
198         self.worker.check_write()
199     }
200 
cancel(&mut self)201     async fn cancel(&mut self) {
202         match self.join_handle.take() {
203             Some(task) => _ = task.cancel().await,
204             None => {}
205         }
206     }
207 }
208 #[async_trait::async_trait]
209 impl Pollable for AsyncWriteStream {
ready(&mut self)210     async fn ready(&mut self) {
211         std::future::poll_fn(|cx| self.poll_ready(cx)).await
212     }
213 }
214