1 use crate::p2::{OutputStream, Pollable, StreamError}; 2 use bytes::Bytes; 3 use std::pin::pin; 4 use std::sync::{Arc, Mutex}; 5 use std::task::{Context, Poll, Waker}; 6 use wasmtime::format_err; 7 8 #[derive(Debug)] 9 struct WorkerState { 10 alive: bool, 11 items: std::collections::VecDeque<Bytes>, 12 write_budget: usize, 13 flush_pending: bool, 14 error: Option<wasmtime::Error>, 15 write_ready_changed: Option<Waker>, 16 } 17 18 impl WorkerState { check_error(&mut self) -> Result<(), StreamError>19 fn check_error(&mut self) -> Result<(), StreamError> { 20 if let Some(e) = self.error.take() { 21 return Err(StreamError::LastOperationFailed(e)); 22 } 23 if !self.alive { 24 return Err(StreamError::Closed); 25 } 26 Ok(()) 27 } 28 } 29 30 struct Worker { 31 state: Mutex<WorkerState>, 32 new_work: tokio::sync::Notify, 33 } 34 35 enum Job { 36 Flush, 37 Write(Bytes), 38 } 39 40 impl Worker { new(write_budget: usize) -> Self41 fn new(write_budget: usize) -> Self { 42 Self { 43 state: Mutex::new(WorkerState { 44 alive: true, 45 items: std::collections::VecDeque::new(), 46 write_budget, 47 flush_pending: false, 48 error: None, 49 write_ready_changed: None, 50 }), 51 new_work: tokio::sync::Notify::new(), 52 } 53 } check_write(&self) -> Result<usize, StreamError>54 fn check_write(&self) -> Result<usize, StreamError> { 55 let mut state = self.state(); 56 if let Err(e) = state.check_error() { 57 return Err(e); 58 } 59 60 if state.flush_pending || state.write_budget == 0 { 61 return Ok(0); 62 } 63 64 Ok(state.write_budget) 65 } state(&self) -> std::sync::MutexGuard<'_, WorkerState>66 fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> { 67 self.state.lock().unwrap() 68 } pop(&self) -> Option<Job>69 fn pop(&self) -> Option<Job> { 70 let mut state = self.state(); 71 if state.items.is_empty() { 72 if state.flush_pending { 73 return Some(Job::Flush); 74 } 75 } else if let Some(bytes) = state.items.pop_front() { 76 return Some(Job::Write(bytes)); 77 } 78 79 None 80 } report_error(&self, e: std::io::Error)81 fn report_error(&self, e: std::io::Error) { 82 let waker = { 83 let mut state = self.state(); 84 state.alive = false; 85 state.error = Some(e.into()); 86 state.flush_pending = false; 87 state.write_ready_changed.take() 88 }; 89 if let Some(waker) = waker { 90 waker.wake(); 91 } 92 } work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T)93 async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) { 94 use tokio::io::AsyncWriteExt; 95 let mut writer = pin!(writer); 96 loop { 97 while let Some(job) = self.pop() { 98 match job { 99 Job::Flush => { 100 if let Err(e) = writer.flush().await { 101 self.report_error(e); 102 return; 103 } 104 105 tracing::debug!("worker marking flush complete"); 106 self.state().flush_pending = false; 107 } 108 109 Job::Write(mut bytes) => { 110 tracing::debug!("worker writing: {bytes:?}"); 111 let len = bytes.len(); 112 match writer.write_all_buf(&mut bytes).await { 113 Err(e) => { 114 self.report_error(e); 115 return; 116 } 117 Ok(_) => { 118 self.state().write_budget += len; 119 } 120 } 121 } 122 } 123 124 let waker = self.state().write_ready_changed.take(); 125 if let Some(waker) = waker { 126 waker.wake(); 127 } 128 } 129 self.new_work.notified().await; 130 } 131 } 132 } 133 134 /// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl 135 pub struct AsyncWriteStream { 136 worker: Arc<Worker>, 137 join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>, 138 } 139 140 impl AsyncWriteStream { 141 /// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl 142 /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self143 pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self { 144 let worker = Arc::new(Worker::new(write_budget)); 145 146 let w = Arc::clone(&worker); 147 let join_handle = crate::runtime::spawn(async move { w.work(writer).await }); 148 149 AsyncWriteStream { 150 worker, 151 join_handle: Some(join_handle), 152 } 153 } 154 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>155 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { 156 let mut state = self.worker.state(); 157 if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0) 158 { 159 return Poll::Ready(()); 160 } 161 state.write_ready_changed = Some(cx.waker().clone()); 162 Poll::Pending 163 } 164 } 165 166 #[async_trait::async_trait] 167 impl OutputStream for AsyncWriteStream { write(&mut self, bytes: Bytes) -> Result<(), StreamError>168 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> { 169 let mut state = self.worker.state(); 170 state.check_error()?; 171 if state.flush_pending { 172 return Err(StreamError::Trap(format_err!( 173 "write not permitted while flush pending" 174 ))); 175 } 176 match state.write_budget.checked_sub(bytes.len()) { 177 Some(remaining_budget) => { 178 state.write_budget = remaining_budget; 179 state.items.push_back(bytes); 180 } 181 None => return Err(StreamError::Trap(format_err!("write exceeded budget"))), 182 } 183 drop(state); 184 self.worker.new_work.notify_one(); 185 Ok(()) 186 } flush(&mut self) -> Result<(), StreamError>187 fn flush(&mut self) -> Result<(), StreamError> { 188 let mut state = self.worker.state(); 189 state.check_error()?; 190 191 state.flush_pending = true; 192 self.worker.new_work.notify_one(); 193 194 Ok(()) 195 } 196 check_write(&mut self) -> Result<usize, StreamError>197 fn check_write(&mut self) -> Result<usize, StreamError> { 198 self.worker.check_write() 199 } 200 cancel(&mut self)201 async fn cancel(&mut self) { 202 match self.join_handle.take() { 203 Some(task) => _ = task.cancel().await, 204 None => {} 205 } 206 } 207 } 208 #[async_trait::async_trait] 209 impl Pollable for AsyncWriteStream { ready(&mut self)210 async fn ready(&mut self) { 211 std::future::poll_fn(|cx| self.poll_ready(cx)).await 212 } 213 } 214