1 use crate::cli::{IsTerminal, StdinStream, StdoutStream};
2 use crate::p2;
3 use bytes::Bytes;
4 use std::mem;
5 use std::pin::Pin;
6 use std::sync::Arc;
7 use std::task::{Context, Poll, ready};
8 use tokio::io::{self, AsyncRead, AsyncWrite};
9 use tokio::sync::{Mutex, OwnedMutexGuard};
10 use wasmtime_wasi_io::streams::{InputStream, OutputStream};
11 
/// Readiness polling for the p2 pipe streams that are shared through
/// [`StdioHandle`].
///
/// `StdioHandle::poll` calls this after acquiring the stream's lock and before
/// running any I/O operation, so an operation is only attempted once the
/// stream reports it is ready.
trait SharedHandleReady: Send + Sync + 'static {
    /// Returns `Poll::Ready(())` once the stream is ready for its next
    /// operation, or `Poll::Pending` if it is not ready yet.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}
15 
/// Forwards readiness polling to the p2 write stream.
impl SharedHandleReady for p2::pipe::AsyncWriteStream {
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // In `<Self>::method` paths inherent associated functions take
        // precedence over trait methods, so this forwards to an inherent
        // `poll_ready` on `AsyncWriteStream` rather than recursing into this
        // trait method. NOTE(review): this relies on that inherent method
        // existing and being visible here; if it were removed this call would
        // silently become infinite recursion.
        <Self>::poll_ready(self, cx)
    }
}
21 
/// Forwards readiness polling to the p2 read stream.
impl SharedHandleReady for p2::pipe::AsyncReadStream {
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // In `<Self>::method` paths inherent associated functions take
        // precedence over trait methods, so this forwards to an inherent
        // `poll_ready` on `AsyncReadStream` rather than recursing into this
        // trait method. NOTE(review): this relies on that inherent method
        // existing and being visible here; if it were removed this call would
        // silently become infinite recursion.
        <Self>::poll_ready(self, cx)
    }
}
27 
/// An impl of [`StdinStream`] built on top of [`AsyncRead`].
///
/// Every stream handed out by [`StdinStream::p2_stream`] or
/// [`StdinStream::async_stream`] shares the same underlying reader.
//
// Note the usage of `tokio::sync::Mutex` here as opposed to a
// `std::sync::Mutex`. This is intentionally done so that the `Pollable` impl
// can `.await` the lock. Note that in doing so we're left with the quandary of
// how to implement methods of `InputStream` since those methods are not
// `async`. They're currently implemented with `try_lock`, which then raises the
// question of what to do on contention. Currently traps are returned.
//
// Why should it be ok to return a trap? In general concurrency/contention
// shouldn't return a trap since it should be able to happen normally. The
// current assumption, though, is that WASI stdin/stdout streams are special
// enough that the contention case should never come up in practice. Currently
// in WASI there is no actual concurrency, there's just the items in a single
// `Store` and that store owns all of its I/O in a single Tokio task. There's no
// means to actually spawn multiple Tokio tasks that use the same store. This
// means at the very least that there's zero parallelism. Due to the lack of
// multiple tasks that also means that there's no concurrency either.
//
// This `AsyncStdinStream` wrapper is only intended to be used by the WASI
// bindings themselves. It's possible for the host to take this and work with it
// on its own task, but that's niche enough it's not designed for.
//
// Overall that means that the guest is either calling `Pollable` or
// `InputStream` methods. This means that there should never be contention
// between the two at this time. This may all change in the future with WASI
// 0.3, but perhaps we'll have a better story for stdio at that time (see the
// comment block on the `OutputStream` impl below).
pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);
57 
58 impl AsyncStdinStream {
new(s: impl AsyncRead + Send + Sync + 'static) -> Self59     pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {
60         Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s))))
61     }
62 }
63 
64 impl StdinStream for AsyncStdinStream {
p2_stream(&self) -> Box<dyn InputStream>65     fn p2_stream(&self) -> Box<dyn InputStream> {
66         Box::new(Self(self.0.clone()))
67     }
async_stream(&self) -> Box<dyn AsyncRead + Send + Sync>68     fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69         Box::new(StdioHandle::Ready(self.0.clone()))
70     }
71 }
72 
73 impl IsTerminal for AsyncStdinStream {
is_terminal(&self) -> bool74     fn is_terminal(&self) -> bool {
75         false
76     }
77 }
78 
79 #[async_trait::async_trait]
80 impl InputStream for AsyncStdinStream {
read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError>81     fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {
82         match self.0.try_lock() {
83             Ok(mut stream) => stream.read(size),
84             Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")),
85         }
86     }
skip(&mut self, size: usize) -> Result<usize, p2::StreamError>87     fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {
88         match self.0.try_lock() {
89             Ok(mut stream) => stream.skip(size),
90             Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")),
91         }
92     }
cancel(&mut self)93     async fn cancel(&mut self) {
94         // Cancel the inner stream if we're the last reference to it:
95         if let Some(mutex) = Arc::get_mut(&mut self.0) {
96             match mutex.try_lock() {
97                 Ok(mut stream) => stream.cancel().await,
98                 Err(_) => {}
99             }
100         }
101     }
102 }
103 
104 #[async_trait::async_trait]
105 impl p2::Pollable for AsyncStdinStream {
ready(&mut self)106     async fn ready(&mut self) {
107         self.0.lock().await.ready().await
108     }
109 }
110 
111 impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> Poll<io::Result<()>>112     fn poll_read(
113         mut self: Pin<&mut Self>,
114         cx: &mut Context<'_>,
115         buf: &mut io::ReadBuf<'_>,
116     ) -> Poll<io::Result<()>> {
117         match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) {
118             Some(Ok(bytes)) => {
119                 buf.put_slice(&bytes);
120                 Poll::Ready(Ok(()))
121             }
122             Some(Err(e)) => Poll::Ready(Err(e)),
123             // If the guard can't be acquired that means that this stream is
124             // closed, so return that we're ready without filling in data.
125             None => Poll::Ready(Ok(())),
126         }
127     }
128 }
129 
/// A wrapper of [`crate::p2::pipe::AsyncWriteStream`] that implements
/// [`StdoutStream`]. Note that the [`OutputStream`] impl for this is not
/// correct when used for interleaved async IO.
///
/// Every stream handed out by [`StdoutStream::p2_stream`] or
/// [`StdoutStream::async_stream`] shares the same underlying writer.
//
// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to
// the `try_lock()` calls in the implementation of `OutputStream` for this type.
// For more information see the documentation on `AsyncStdinStream`.
pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);
138 
139 impl AsyncStdoutStream {
new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self140     pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {
141         Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new(
142             budget, s,
143         ))))
144     }
145 }
146 
147 impl StdoutStream for AsyncStdoutStream {
p2_stream(&self) -> Box<dyn OutputStream>148     fn p2_stream(&self) -> Box<dyn OutputStream> {
149         Box::new(Self(self.0.clone()))
150     }
async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync>151     fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
152         Box::new(StdioHandle::Ready(self.0.clone()))
153     }
154 }
155 
156 impl IsTerminal for AsyncStdoutStream {
is_terminal(&self) -> bool157     fn is_terminal(&self) -> bool {
158         false
159     }
160 }
161 
162 // This implementation is known to be bogus. All check-writes and writes are
163 // directed at the same underlying stream. The check-write/write protocol does
164 // require the size returned by a check-write to be accepted by write, even if
165 // other side-effects happen between those calls, and this implementation
166 // permits another view (created by StdoutStream::stream()) of the same
167 // underlying stream to accept a write which will invalidate a prior
168 // check-write of another view.
169 // Ultimately, the Std{in,out}Stream::stream() methods exist because many
170 // different places in a linked component (which may itself contain many
171 // modules) may need to access stdio without any coordination to keep those
172 // accesses all using pointing to the same resource. So, we allow many
173 // resources to be created. We have the reasonable expectation that programs
174 // won't attempt to interleave async IO from these disparate uses of stdio.
175 // If that expectation doesn't turn out to be true, and you find yourself at
176 // this comment to correct it: sorry about that.
177 #[async_trait::async_trait]
178 impl OutputStream for AsyncStdoutStream {
check_write(&mut self) -> Result<usize, p2::StreamError>179     fn check_write(&mut self) -> Result<usize, p2::StreamError> {
180         match self.0.try_lock() {
181             Ok(mut stream) => stream.check_write(),
182             Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")),
183         }
184     }
write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError>185     fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {
186         match self.0.try_lock() {
187             Ok(mut stream) => stream.write(bytes),
188             Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")),
189         }
190     }
flush(&mut self) -> Result<(), p2::StreamError>191     fn flush(&mut self) -> Result<(), p2::StreamError> {
192         match self.0.try_lock() {
193             Ok(mut stream) => stream.flush(),
194             Err(_) => Err(p2::StreamError::trap(
195                 "concurrent flushes not supported yet",
196             )),
197         }
198     }
cancel(&mut self)199     async fn cancel(&mut self) {
200         // Cancel the inner stream if we're the last reference to it:
201         if let Some(mutex) = Arc::get_mut(&mut self.0) {
202             match mutex.try_lock() {
203                 Ok(mut stream) => stream.cancel().await,
204                 Err(_) => {}
205             }
206         }
207     }
208 }
209 
210 #[async_trait::async_trait]
211 impl p2::Pollable for AsyncStdoutStream {
ready(&mut self)212     async fn ready(&mut self) {
213         self.0.lock().await.ready().await
214     }
215 }
216 
217 impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>218     fn poll_write(
219         self: Pin<&mut Self>,
220         cx: &mut Context<'_>,
221         buf: &[u8],
222     ) -> Poll<io::Result<usize>> {
223         match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {
224             Some(Ok(())) => Poll::Ready(Ok(buf.len())),
225             Some(Err(e)) => Poll::Ready(Err(e)),
226             None => Poll::Ready(Ok(0)),
227         }
228     }
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>229     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230         match ready!(self.poll(cx, |i| i.flush())) {
231             Some(result) => Poll::Ready(result),
232             None => Poll::Ready(Ok(())),
233         }
234     }
poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>>235     fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
236         Poll::Ready(Ok(()))
237     }
238 }
239 
/// State necessary for effectively transforming a shared `Arc<Mutex<S>>` into
/// `Async{Read,Write}`. Here `S` is one of the p2 pipe streams
/// (`AsyncReadStream` for stdin, `AsyncWriteStream` for stdout).
///
/// This is a beast and inefficient. It should get the job done in theory but
/// one must truly ask oneself at some point "but at what cost".
///
/// More seriously, it's unclear if this is the best way to transform a single
/// `AsyncRead` into a "multiple `AsyncRead`". This certainly is an attempt and
/// the hope is that everything here is private enough that we can refactor as
/// necessary in the future without causing much churn.
enum StdioHandle<S> {
    /// Not holding the lock; the next poll starts acquiring it.
    Ready(Arc<Mutex<S>>),
    /// Lock acquisition in progress (a boxed `Mutex::lock_owned` future).
    Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),
    /// Holding the lock, waiting for the stream to become ready.
    Locked(OwnedMutexGuard<S>),
    /// The stream was closed or failed; no further I/O is attempted.
    Closed,
}
256 
impl<S> StdioHandle<S>
where
    S: SharedHandleReady,
{
    /// Runs `op` once against the shared stream, driving the lock state
    /// machine as needed.
    ///
    /// The lock is acquired (if not already held), then the stream is polled
    /// for readiness, and only then is `op` invoked. The lock stays held
    /// across `Poll::Pending` returns.
    ///
    /// Returns:
    /// - `Ready(Some(Ok(t)))` when `op` succeeds; the lock is released and the
    ///   handle goes back to `Ready`.
    /// - `Ready(Some(Err(e)))` when `op` fails with `LastOperationFailed`;
    ///   the handle is left `Closed`.
    /// - `Ready(None)` when the handle was already `Closed`, or `op` reported
    ///   `StreamError::Closed` (the handle is left `Closed`).
    ///
    /// # Panics
    ///
    /// Panics if `op` returns `StreamError::Trap`, or if a
    /// `LastOperationFailed` error does not downcast to `io::Error`.
    fn poll<T>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        op: impl FnOnce(&mut S) -> p2::StreamResult<T>,
    ) -> Poll<Option<io::Result<T>>> {
        // If we don't currently have the lock on this handle, initiate the
        // lock acquisition.
        if let StdioHandle::Ready(lock) = &*self {
            self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));
        }

        // If we're in the process of locking this handle, wait for that to
        // finish.
        if let Some(lock) = self.as_mut().as_locking() {
            let guard = ready!(lock.poll(cx));
            self.set(StdioHandle::Locked(guard));
        }

        // Move the guard out, leaving `Closed` behind. Every early exit below
        // that does not explicitly restore a state therefore leaves the handle
        // closed.
        let mut guard = match self.as_mut().take_guard() {
            Some(guard) => guard,
            // If the guard can't be acquired that means that this stream is
            // closed, so return that we're ready without filling in data.
            None => return Poll::Ready(None),
        };

        // Wait for our locked stream to be ready, resetting to the "locked"
        // state if it's not quite ready yet.
        match guard.poll_ready(cx) {
            Poll::Ready(()) => {}

            // If the stream isn't ready yet then restore our "locked" state
            // since we haven't finished, then return pending.
            Poll::Pending => {
                self.set(StdioHandle::Locked(guard));
                return Poll::Pending;
            }
        }

        // Perform the I/O and delegate on the result.
        match op(&mut guard) {
            // The I/O succeeded so relinquish the lock on this stream by
            // transitioning back to the "Ready" state.
            Ok(result) => {
                self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));
                Poll::Ready(Some(Ok(result)))
            }

            // The stream is closed, and `take_guard` above already set the
            // closed state, so return nothing indicating the closure.
            Err(p2::StreamError::Closed) => Poll::Ready(None),

            // The stream failed so propagate the error. Errors should only
            // come from the underlying I/O object and thus should cast
            // successfully. Additionally `take_guard` replaced our state
            // with "closed" above which is the desired state at this point.
            Err(p2::StreamError::LastOperationFailed(e)) => {
                Poll::Ready(Some(Err(e.downcast().unwrap())))
            }

            // Shouldn't be possible to produce a trap here: callers must only
            // pass operations that respect the stream's protocol.
            Err(p2::StreamError::Trap(_)) => unreachable!(),
        }
    }

    /// Returns a pinned reference to the in-flight lock future when in the
    /// `Locking` state, or `None` otherwise.
    fn as_locking(
        self: Pin<&mut Self>,
    ) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {
        // SAFETY: this is a pin-projection from `self` into the future boxed in
        // the `Locking` variant. The future lives in its own heap allocation,
        // so it is never moved when `self` changes state; it is only dropped
        // in place when the state is replaced (e.g. via `Pin::set`).
        unsafe {
            match self.get_unchecked_mut() {
                StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),
                _ => None,
            }
        }
    }

    /// Moves the lock guard out of the `Locked` state, leaving `Closed` in its
    /// place. Returns `None` (and leaves the state untouched) in any other
    /// state.
    fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {
        if !matches!(*self, StdioHandle::Locked(_)) {
            return None;
        }
        // SAFETY: the `Locked` arm is safe to move as it's an invariant of this
        // type that it's not pinned; the only pin-projected data is the boxed
        // future of the `Locking` variant, and the check above guarantees we
        // are not in that state.
        unsafe {
            match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {
                StdioHandle::Locked(guard) => Some(guard),
                _ => unreachable!(),
            }
        }
    }
}
352