1 use crate::cli::{IsTerminal, StdinStream, StdoutStream}; 2 use crate::p2; 3 use bytes::Bytes; 4 use std::mem; 5 use std::pin::Pin; 6 use std::sync::Arc; 7 use std::task::{Context, Poll, ready}; 8 use tokio::io::{self, AsyncRead, AsyncWrite}; 9 use tokio::sync::{Mutex, OwnedMutexGuard}; 10 use wasmtime_wasi_io::streams::{InputStream, OutputStream}; 11 12 trait SharedHandleReady: Send + Sync + 'static { poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>13 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; 14 } 15 16 impl SharedHandleReady for p2::pipe::AsyncWriteStream { poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>17 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { 18 <Self>::poll_ready(self, cx) 19 } 20 } 21 22 impl SharedHandleReady for p2::pipe::AsyncReadStream { poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>23 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { 24 <Self>::poll_ready(self, cx) 25 } 26 } 27 28 /// An impl of [`StdinStream`] built on top of [`AsyncRead`]. 29 // 30 // Note the usage of `tokio::sync::Mutex` here as opposed to a 31 // `std::sync::Mutex`. This is intentionally done to implement the `Pollable` 32 // variant of this trait. Note that in doing so we're left with the quandry of 33 // how to implement methods of `InputStream` since those methods are not 34 // `async`. They're currently implemented with `try_lock`, which then raises the 35 // question of what to do on contention. Currently traps are returned. 36 // 37 // Why should it be ok to return a trap? In general concurrency/contention 38 // shouldn't return a trap since it should be able to happen normally. The 39 // current assumption, though, is that WASI stdin/stdout streams are special 40 // enough that the contention case should never come up in practice. Currently 41 // in WASI there is no actually concurrency, there's just the items in a single 42 // `Store` and that store owns all of its I/O in a single Tokio task. There's no 43 // means to actually spawn multiple Tokio tasks that use the same store. This 44 // means at the very least that there's zero parallelism. Due to the lack of 45 // multiple tasks that also means that there's no concurrency either. 46 // 47 // This `AsyncStdinStream` wrapper is only intended to be used by the WASI 48 // bindings themselves. It's possible for the host to take this and work with it 49 // on its own task, but that's niche enough it's not designed for. 50 // 51 // Overall that means that the guest is either calling `Pollable` or 52 // `InputStream` methods. This means that there should never be contention 53 // between the two at this time. This may all change in the future with WASI 54 // 0.3, but perhaps we'll have a better story for stdio at that time (see the 55 // doc block on the `OutputStream` impl below) 56 pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>); 57 58 impl AsyncStdinStream { new(s: impl AsyncRead + Send + Sync + 'static) -> Self59 pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self { 60 Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s)))) 61 } 62 } 63 64 impl StdinStream for AsyncStdinStream { p2_stream(&self) -> Box<dyn InputStream>65 fn p2_stream(&self) -> Box<dyn InputStream> { 66 Box::new(Self(self.0.clone())) 67 } async_stream(&self) -> Box<dyn AsyncRead + Send + Sync>68 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> { 69 Box::new(StdioHandle::Ready(self.0.clone())) 70 } 71 } 72 73 impl IsTerminal for AsyncStdinStream { is_terminal(&self) -> bool74 fn is_terminal(&self) -> bool { 75 false 76 } 77 } 78 79 #[async_trait::async_trait] 80 impl InputStream for AsyncStdinStream { read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError>81 fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> { 82 match self.0.try_lock() { 83 Ok(mut stream) => stream.read(size), 84 Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")), 85 } 86 } skip(&mut self, size: usize) -> Result<usize, p2::StreamError>87 fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> { 88 match self.0.try_lock() { 89 Ok(mut stream) => stream.skip(size), 90 Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")), 91 } 92 } cancel(&mut self)93 async fn cancel(&mut self) { 94 // Cancel the inner stream if we're the last reference to it: 95 if let Some(mutex) = Arc::get_mut(&mut self.0) { 96 match mutex.try_lock() { 97 Ok(mut stream) => stream.cancel().await, 98 Err(_) => {} 99 } 100 } 101 } 102 } 103 104 #[async_trait::async_trait] 105 impl p2::Pollable for AsyncStdinStream { ready(&mut self)106 async fn ready(&mut self) { 107 self.0.lock().await.ready().await 108 } 109 } 110 111 impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> Poll<io::Result<()>>112 fn poll_read( 113 mut self: Pin<&mut Self>, 114 cx: &mut Context<'_>, 115 buf: &mut io::ReadBuf<'_>, 116 ) -> Poll<io::Result<()>> { 117 match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) { 118 Some(Ok(bytes)) => { 119 buf.put_slice(&bytes); 120 Poll::Ready(Ok(())) 121 } 122 Some(Err(e)) => Poll::Ready(Err(e)), 123 // If the guard can't be acquired that means that this stream is 124 // closed, so return that we're ready without filling in data. 125 None => Poll::Ready(Ok(())), 126 } 127 } 128 } 129 130 /// A wrapper of [`crate::p2::pipe::AsyncWriteStream`] that implements 131 /// [`StdoutStream`]. Note that the [`OutputStream`] impl for this is not 132 /// correct when used for interleaved async IO. 133 // 134 // Note that the use of `tokio::sync::Mutex` here is intentional, in addition to 135 // the `try_lock()` calls below in the implementation of `OutputStream`. For 136 // more information see the documentation on `AsyncStdinStream`. 137 pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>); 138 139 impl AsyncStdoutStream { new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self140 pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self { 141 Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new( 142 budget, s, 143 )))) 144 } 145 } 146 147 impl StdoutStream for AsyncStdoutStream { p2_stream(&self) -> Box<dyn OutputStream>148 fn p2_stream(&self) -> Box<dyn OutputStream> { 149 Box::new(Self(self.0.clone())) 150 } async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync>151 fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> { 152 Box::new(StdioHandle::Ready(self.0.clone())) 153 } 154 } 155 156 impl IsTerminal for AsyncStdoutStream { is_terminal(&self) -> bool157 fn is_terminal(&self) -> bool { 158 false 159 } 160 } 161 162 // This implementation is known to be bogus. All check-writes and writes are 163 // directed at the same underlying stream. The check-write/write protocol does 164 // require the size returned by a check-write to be accepted by write, even if 165 // other side-effects happen between those calls, and this implementation 166 // permits another view (created by StdoutStream::stream()) of the same 167 // underlying stream to accept a write which will invalidate a prior 168 // check-write of another view. 169 // Ultimately, the Std{in,out}Stream::stream() methods exist because many 170 // different places in a linked component (which may itself contain many 171 // modules) may need to access stdio without any coordination to keep those 172 // accesses all using pointing to the same resource. So, we allow many 173 // resources to be created. We have the reasonable expectation that programs 174 // won't attempt to interleave async IO from these disparate uses of stdio. 175 // If that expectation doesn't turn out to be true, and you find yourself at 176 // this comment to correct it: sorry about that. 177 #[async_trait::async_trait] 178 impl OutputStream for AsyncStdoutStream { check_write(&mut self) -> Result<usize, p2::StreamError>179 fn check_write(&mut self) -> Result<usize, p2::StreamError> { 180 match self.0.try_lock() { 181 Ok(mut stream) => stream.check_write(), 182 Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")), 183 } 184 } write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError>185 fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> { 186 match self.0.try_lock() { 187 Ok(mut stream) => stream.write(bytes), 188 Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")), 189 } 190 } flush(&mut self) -> Result<(), p2::StreamError>191 fn flush(&mut self) -> Result<(), p2::StreamError> { 192 match self.0.try_lock() { 193 Ok(mut stream) => stream.flush(), 194 Err(_) => Err(p2::StreamError::trap( 195 "concurrent flushes not supported yet", 196 )), 197 } 198 } cancel(&mut self)199 async fn cancel(&mut self) { 200 // Cancel the inner stream if we're the last reference to it: 201 if let Some(mutex) = Arc::get_mut(&mut self.0) { 202 match mutex.try_lock() { 203 Ok(mut stream) => stream.cancel().await, 204 Err(_) => {} 205 } 206 } 207 } 208 } 209 210 #[async_trait::async_trait] 211 impl p2::Pollable for AsyncStdoutStream { ready(&mut self)212 async fn ready(&mut self) { 213 self.0.lock().await.ready().await 214 } 215 } 216 217 impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>218 fn poll_write( 219 self: Pin<&mut Self>, 220 cx: &mut Context<'_>, 221 buf: &[u8], 222 ) -> Poll<io::Result<usize>> { 223 match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) { 224 Some(Ok(())) => Poll::Ready(Ok(buf.len())), 225 Some(Err(e)) => Poll::Ready(Err(e)), 226 None => Poll::Ready(Ok(0)), 227 } 228 } poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>229 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 230 match ready!(self.poll(cx, |i| i.flush())) { 231 Some(result) => Poll::Ready(result), 232 None => Poll::Ready(Ok(())), 233 } 234 } poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>>235 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { 236 Poll::Ready(Ok(())) 237 } 238 } 239 240 /// State necessary for effectively transforming `Arc<Mutex<dyn 241 /// {Input,Output}Stream>>` into `Async{Read,Write}`. 242 /// 243 /// This is a beast and inefficient. It should get the job done in theory but 244 /// one must truly ask oneself at some point "but at what cost". 245 /// 246 /// More seriously, it's unclear if this is the best way to transform a single 247 /// `AsyncRead` into a "multiple `AsyncRead`". This certainly is an attempt and 248 /// the hope is that everything here is private enough that we can refactor as 249 /// necessary in the future without causing much churn. 250 enum StdioHandle<S> { 251 Ready(Arc<Mutex<S>>), 252 Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>), 253 Locked(OwnedMutexGuard<S>), 254 Closed, 255 } 256 257 impl<S> StdioHandle<S> 258 where 259 S: SharedHandleReady, 260 { poll<T>( mut self: Pin<&mut Self>, cx: &mut Context<'_>, op: impl FnOnce(&mut S) -> p2::StreamResult<T>, ) -> Poll<Option<io::Result<T>>>261 fn poll<T>( 262 mut self: Pin<&mut Self>, 263 cx: &mut Context<'_>, 264 op: impl FnOnce(&mut S) -> p2::StreamResult<T>, 265 ) -> Poll<Option<io::Result<T>>> { 266 // If we don't currently have the lock on this handle, initiate the 267 // lock acquisition. 268 if let StdioHandle::Ready(lock) = &*self { 269 self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned()))); 270 } 271 272 // If we're in the process of locking this handle, wait for that to 273 // finish. 274 if let Some(lock) = self.as_mut().as_locking() { 275 let guard = ready!(lock.poll(cx)); 276 self.set(StdioHandle::Locked(guard)); 277 } 278 279 let mut guard = match self.as_mut().take_guard() { 280 Some(guard) => guard, 281 // If the guard can't be acquired that means that this stream is 282 // closed, so return that we're ready without filling in data. 283 None => return Poll::Ready(None), 284 }; 285 286 // Wait for our locked stream to be ready, resetting to the "locked" 287 // state if it's not quite ready yet. 288 match guard.poll_ready(cx) { 289 Poll::Ready(()) => {} 290 291 // If the read isn't ready yet then restore our "locked" state 292 // since we haven't finished, then return pending. 293 Poll::Pending => { 294 self.set(StdioHandle::Locked(guard)); 295 return Poll::Pending; 296 } 297 } 298 299 // Perform the I/O and delegate on the result. 300 match op(&mut guard) { 301 // The I/O succeeded so relinquish the lock on this stream by 302 // transitioning back to the "Ready" state. 303 Ok(result) => { 304 self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone())); 305 Poll::Ready(Some(Ok(result))) 306 } 307 308 // The stream is closed, and `take_guard` above already set the 309 // closed state, so return nothing indicating the closure. 310 Err(p2::StreamError::Closed) => Poll::Ready(None), 311 312 // The stream failed so propagate the error. Errors should only 313 // come from the underlying I/O object and thus should cast 314 // successfully. Additionally `take_guard` replaced our state 315 // with "closed" above which is the desired state at this point. 316 Err(p2::StreamError::LastOperationFailed(e)) => { 317 Poll::Ready(Some(Err(e.downcast().unwrap()))) 318 } 319 320 // Shouldn't be possible to produce a trap here. 321 Err(p2::StreamError::Trap(_)) => unreachable!(), 322 } 323 } 324 as_locking( self: Pin<&mut Self>, ) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>>325 fn as_locking( 326 self: Pin<&mut Self>, 327 ) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> { 328 // SAFETY: this is a pin-projection from `self` into the `Locking` 329 // field. 330 unsafe { 331 match self.get_unchecked_mut() { 332 StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)), 333 _ => None, 334 } 335 } 336 } 337 take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>>338 fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> { 339 if !matches!(*self, StdioHandle::Locked(_)) { 340 return None; 341 } 342 // SAFETY: the `Locked` arm is safe to move as it's an invariant of this 343 // type that it's not pinned. 344 unsafe { 345 match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) { 346 StdioHandle::Locked(guard) => Some(guard), 347 _ => unreachable!(), 348 } 349 } 350 } 351 } 352