1 //! Handling for standard in using a worker task.
2 //!
3 //! Standard input is a global singleton resource for the entire program which
4 //! needs special care. Currently this implementation adheres to a few
5 //! constraints which make this nontrivial to implement.
6 //!
7 //! * Any number of guest wasm programs can read stdin. While this doesn't make
8 //!   a ton of sense semantically they shouldn't block forever. Instead it's a
9 //!   race to see who actually reads which parts of stdin.
10 //!
11 //! * Data from stdin isn't actually read unless requested. This is done to try
12 //!   to be a good neighbor to others running in the process. Under the
13 //!   assumption that most programs have one "thing" which reads stdin the
14 //!   actual consumption of bytes is delayed until the wasm guest is dynamically
15 //!   chosen to be that "thing". Before that data from stdin is not consumed to
16 //!   avoid taking it from other components in the process.
17 //!
18 //! * Tokio's documentation indicates that "interactive stdin" is best done with
19 //!   a helper thread to avoid blocking shutdown of the event loop. That's
20 //!   respected here where all stdin reading happens on a blocking helper thread
21 //!   that, at this time, is never shut down.
22 //!
23 //! This module is one that's likely to change over time though as new systems
24 //! are encountered along with preexisting bugs.
25 
26 use crate::cli::{IsTerminal, StdinStream};
27 use bytes::{Bytes, BytesMut};
28 use std::io::Read;
29 use std::mem;
30 use std::pin::Pin;
31 use std::sync::{Condvar, Mutex, OnceLock};
32 use std::task::{Context, Poll};
33 use tokio::io::{self, AsyncRead, ReadBuf};
34 use tokio::sync::Notify;
35 use tokio::sync::futures::Notified;
36 use wasmtime_wasi_io::{
37     poll::Pollable,
38     streams::{InputStream, StreamError},
39 };
40 
41 // Implementation for tokio::io::Stdin
42 impl IsTerminal for tokio::io::Stdin {
43     fn is_terminal(&self) -> bool {
44         std::io::stdin().is_terminal()
45     }
46 }
47 impl StdinStream for tokio::io::Stdin {
48     fn p2_stream(&self) -> Box<dyn InputStream> {
49         Box::new(WasiStdin)
50     }
51     fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
52         Box::new(WasiStdinAsyncRead::Ready)
53     }
54 }
55 
56 // Implementation for std::io::Stdin
57 impl IsTerminal for std::io::Stdin {
58     fn is_terminal(&self) -> bool {
59         std::io::IsTerminal::is_terminal(self)
60     }
61 }
62 impl StdinStream for std::io::Stdin {
63     fn p2_stream(&self) -> Box<dyn InputStream> {
64         Box::new(WasiStdin)
65     }
66     fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
67         Box::new(WasiStdinAsyncRead::Ready)
68     }
69 }
70 
71 #[derive(Default)]
72 struct GlobalStdin {
73     state: Mutex<StdinState>,
74     read_requested: Condvar,
75     read_completed: Notify,
76 }
77 
78 #[derive(Default, Debug)]
79 enum StdinState {
80     #[default]
81     ReadNotRequested,
82     ReadRequested,
83     Data(BytesMut),
84     Error(std::io::Error),
85     Closed,
86 }
87 
88 impl GlobalStdin {
89     fn get() -> &'static GlobalStdin {
90         static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
91         STDIN.get_or_init(|| create())
92     }
93 }
94 
95 fn create() -> GlobalStdin {
96     std::thread::spawn(|| {
97         let state = GlobalStdin::get();
98         loop {
99             // Wait for a read to be requested, but don't hold the lock across
100             // the blocking read.
101             let mut lock = state.state.lock().unwrap();
102             lock = state
103                 .read_requested
104                 .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
105                 .unwrap();
106             drop(lock);
107 
108             let mut bytes = BytesMut::zeroed(1024);
109             let (new_state, done) = match std::io::stdin().read(&mut bytes) {
110                 Ok(0) => (StdinState::Closed, true),
111                 Ok(nbytes) => {
112                     bytes.truncate(nbytes);
113                     (StdinState::Data(bytes), false)
114                 }
115                 Err(e) => (StdinState::Error(e), true),
116             };
117 
118             // After the blocking read completes the state should not have been
119             // tampered with.
120             debug_assert!(matches!(
121                 *state.state.lock().unwrap(),
122                 StdinState::ReadRequested
123             ));
124             let mut lock = state.state.lock().unwrap();
125             *lock = new_state;
126             state.read_completed.notify_waiters();
127             if done {
128                 break;
129             }
130         }
131     });
132 
133     GlobalStdin::default()
134 }
135 
136 struct WasiStdin;
137 
138 #[async_trait::async_trait]
139 impl InputStream for WasiStdin {
140     fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
141         let g = GlobalStdin::get();
142         let mut locked = g.state.lock().unwrap();
143         match mem::replace(&mut *locked, StdinState::ReadRequested) {
144             StdinState::ReadNotRequested => {
145                 g.read_requested.notify_one();
146                 Ok(Bytes::new())
147             }
148             StdinState::ReadRequested => Ok(Bytes::new()),
149             StdinState::Data(mut data) => {
150                 let size = data.len().min(size);
151                 let bytes = data.split_to(size);
152                 *locked = if data.is_empty() {
153                     StdinState::ReadNotRequested
154                 } else {
155                     StdinState::Data(data)
156                 };
157                 Ok(bytes.freeze())
158             }
159             StdinState::Error(e) => {
160                 *locked = StdinState::Closed;
161                 Err(StreamError::LastOperationFailed(e.into()))
162             }
163             StdinState::Closed => {
164                 *locked = StdinState::Closed;
165                 Err(StreamError::Closed)
166             }
167         }
168     }
169 }
170 
171 #[async_trait::async_trait]
172 impl Pollable for WasiStdin {
173     async fn ready(&mut self) {
174         let g = GlobalStdin::get();
175 
176         // Scope the synchronous `state.lock()` to this block which does not
177         // `.await` inside of it.
178         let notified = {
179             let mut locked = g.state.lock().unwrap();
180             match *locked {
181                 // If a read isn't requested yet
182                 StdinState::ReadNotRequested => {
183                     g.read_requested.notify_one();
184                     *locked = StdinState::ReadRequested;
185                     g.read_completed.notified()
186                 }
187                 StdinState::ReadRequested => g.read_completed.notified(),
188                 StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
189             }
190         };
191 
192         notified.await;
193     }
194 }
195 
196 enum WasiStdinAsyncRead {
197     Ready,
198     Waiting(Notified<'static>),
199 }
200 
201 impl AsyncRead for WasiStdinAsyncRead {
202     fn poll_read(
203         mut self: Pin<&mut Self>,
204         cx: &mut Context<'_>,
205         buf: &mut ReadBuf<'_>,
206     ) -> Poll<io::Result<()>> {
207         let g = GlobalStdin::get();
208 
209         // Everything below is executed under the global stdin lock. It's not
210         // going to block below so that's semantically fine. Optimization-wise
211         // it's probably possible to move this within the loop around just a
212         // small part of reading/writing the state, but that was done
213         // historically and it resulted in lost wakeups with `Notify`, so this
214         // is conservatively hoisted up here.
215         let mut locked = g.state.lock().unwrap();
216 
217         // Perform everything below in a `loop` to handle the case that a read
218         // was stolen by another thread, for example, or perhaps a spurious
219         // notification to `Notified`.
220         loop {
221             // If we were previously blocked on reading a "ready" notification,
222             // wait for that notification to complete.
223             if let Some(notified) = self.as_mut().notified_future() {
224                 match notified.poll(cx) {
225                     Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
226                     Poll::Pending => break Poll::Pending,
227                 }
228             }
229 
230             assert!(matches!(*self, WasiStdinAsyncRead::Ready));
231 
232             // Once we're in the "ready" state then take a look at the global
233             // state of stdin.
234             match mem::replace(&mut *locked, StdinState::ReadRequested) {
235                 // If data is available then drain what we can into `buf`.
236                 StdinState::Data(mut data) => {
237                     let size = data.len().min(buf.remaining());
238                     let bytes = data.split_to(size);
239                     *locked = if data.is_empty() {
240                         StdinState::ReadNotRequested
241                     } else {
242                         StdinState::Data(data)
243                     };
244                     buf.put_slice(&bytes);
245                     break Poll::Ready(Ok(()));
246                 }
247 
248                 // If stdin failed to be read then we fail with that error and
249                 // transition to "closed"
250                 StdinState::Error(e) => {
251                     *locked = StdinState::Closed;
252                     break Poll::Ready(Err(e));
253                 }
254 
255                 // If stdin is closed, keep it closed.
256                 StdinState::Closed => {
257                     *locked = StdinState::Closed;
258                     break Poll::Ready(Ok(()));
259                 }
260 
261                 // For these states we indicate that a read is requested, if it
262                 // wasn't previously requested, and then we transition to
263                 // `Waiting` below by falling through outside this `match`.
264                 StdinState::ReadNotRequested => {
265                     g.read_requested.notify_one();
266                 }
267                 StdinState::ReadRequested => {}
268             }
269 
270             self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
271         }
272     }
273 }
274 
275 impl WasiStdinAsyncRead {
276     fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
277         // SAFETY: this is a pin-projection from `self` to the field `Notified`
278         // internally. Given that `self` is pinned it should be safe to acquire
279         // a pinned version of the internal field.
280         unsafe {
281             match self.get_unchecked_mut() {
282                 WasiStdinAsyncRead::Ready => None,
283                 WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
284             }
285         }
286     }
287 }
288