1 use crate::{
2     Error, ErrorExt,
3     file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
4 };
5 #[cfg(windows)]
6 use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
7 use io_lifetimes::AsSocketlike;
8 #[cfg(unix)]
9 use io_lifetimes::{AsFd, BorrowedFd};
10 #[cfg(windows)]
11 use io_lifetimes::{AsSocket, BorrowedSocket};
12 use std::any::Any;
13 use std::io;
14 #[cfg(unix)]
15 use system_interface::fs::GetSetFdFlags;
16 use system_interface::io::IoExt;
17 use system_interface::io::IsReadWrite;
18 use system_interface::io::ReadReady;
19 
20 pub enum Socket {
21     TcpListener(cap_std::net::TcpListener),
22     TcpStream(cap_std::net::TcpStream),
23     #[cfg(unix)]
24     UnixStream(cap_std::os::unix::net::UnixStream),
25     #[cfg(unix)]
26     UnixListener(cap_std::os::unix::net::UnixListener),
27 }
28 
29 impl From<cap_std::net::TcpListener> for Socket {
from(listener: cap_std::net::TcpListener) -> Self30     fn from(listener: cap_std::net::TcpListener) -> Self {
31         Self::TcpListener(listener)
32     }
33 }
34 
35 impl From<cap_std::net::TcpStream> for Socket {
from(stream: cap_std::net::TcpStream) -> Self36     fn from(stream: cap_std::net::TcpStream) -> Self {
37         Self::TcpStream(stream)
38     }
39 }
40 
41 #[cfg(unix)]
42 impl From<cap_std::os::unix::net::UnixListener> for Socket {
from(listener: cap_std::os::unix::net::UnixListener) -> Self43     fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
44         Self::UnixListener(listener)
45     }
46 }
47 
48 #[cfg(unix)]
49 impl From<cap_std::os::unix::net::UnixStream> for Socket {
from(stream: cap_std::os::unix::net::UnixStream) -> Self50     fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
51         Self::UnixStream(stream)
52     }
53 }
54 
55 #[cfg(unix)]
56 impl From<Socket> for Box<dyn WasiFile> {
from(listener: Socket) -> Self57     fn from(listener: Socket) -> Self {
58         match listener {
59             Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
60             Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),
61             Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
62             Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),
63         }
64     }
65 }
66 
#[cfg(windows)]
impl From<Socket> for Box<dyn WasiFile> {
    /// Converts a [`Socket`] into the matching `WasiFile` wrapper type,
    /// boxed as a trait object. Only the TCP variants exist on Windows.
    fn from(socket: Socket) -> Self {
        match socket {
            Socket::TcpListener(inner) => {
                Box::new(crate::sync::net::TcpListener::from_cap_std(inner))
            }
            Socket::TcpStream(inner) => Box::new(crate::sync::net::TcpStream::from_cap_std(inner)),
        }
    }
}
76 
// Implements `WasiFile` plus the platform handle-borrowing traits for a
// listener newtype.
//
// * `$ty`     - tuple struct whose field `.0` is the `cap_std` listener.
// * `$stream` - wrapper type used for accepted connections; it must offer
//               `from_cap_std` and `set_fdflags`.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd<'_>> {
                Some(self.0.as_fd())
            }

            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }

            // Accepts one connection, applies `fdflags` to it, and hands it
            // back as a boxed `WasiFile`. The peer address is discarded.
            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                let (accepted, _peer_addr) = self.0.accept()?;
                let mut wrapped = <$stream>::from_cap_std(accepted);
                wrapped.set_fdflags(fdflags).await?;
                Ok(Box::new(wrapped))
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            // Only two settings are accepted: exactly `NONBLOCK`, or no flags
            // at all. Anything else is rejected as an invalid argument.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                let nonblocking = if fdflags == crate::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            // A listener always reports a constant `1` here.
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                Ok(1)
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
147 
148 pub struct TcpListener(cap_std::net::TcpListener);
149 
150 impl TcpListener {
from_cap_std(cap_std: cap_std::net::TcpListener) -> Self151     pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
152         TcpListener(cap_std)
153     }
154 }
155 wasi_listen_write_impl!(TcpListener, TcpStream);
156 
157 #[cfg(unix)]
158 pub struct UnixListener(cap_std::os::unix::net::UnixListener);
159 
160 #[cfg(unix)]
161 impl UnixListener {
from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self162     pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
163         UnixListener(cap_std)
164     }
165 }
166 
167 #[cfg(unix)]
168 wasi_listen_write_impl!(UnixListener, UnixStream);
169 
// Implements `WasiFile` plus the platform handle-borrowing traits for a
// connected stream-socket newtype.
//
// * `$ty`     - tuple struct whose field `.0` is the `cap_std` stream.
// * `$std_ty` - the `std` socket type used as a "socketlike view" of `$ty`
//               for the plain `Read`/`Write`/`ReadReady` operations.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[async_trait::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd<'_>> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            // Only two settings are accepted: exactly `NONBLOCK`, or no flags
            // at all. Anything else is rejected as an invalid argument.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                if fdflags == crate::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            // Reads through a `$std_ty` view of the socket; `&$std_ty`
            // implements `Read`, so `&self` is enough.
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            // Writes through a `$std_ty` view of the socket, mirroring
            // `read_vectored`.
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            // Succeeds when the socket reports itself readable; otherwise
            // yields an I/O error.
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable { Ok(()) } else { Err(Error::io()) }
            }
            // Succeeds when the socket reports itself writable; otherwise
            // yields an I/O error.
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable { Ok(()) } else { Err(Error::io()) }
            }

            // Receives data into `ri_data`.
            //
            // Supported flags are `RECV_PEEK` and `RECV_WAITALL`; any other
            // bit yields a "not supported" error. The returned `RoFlags` are
            // always empty.
            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                // Peeking only ever fills the first buffer; with no buffers
                // at all there is nothing to peek into.
                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    if let Some(first) = ri_data.first_mut() {
                        let n = self.0.peek(first)?;
                        return Ok((n as u64, RoFlags::empty()));
                    } else {
                        return Ok((0, RoFlags::empty()));
                    }
                }

                // WAITALL fills every buffer completely, so the byte count
                // reported is the combined length of all buffers.
                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            // Sends `si_data`. No send flags are supported; any non-empty
            // `si_flags` yields a "not supported" error.
            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            // Maps the WASI shutdown flags onto `cap_std::net::Shutdown`.
            // Any combination other than RD, WR, or RD|WR (including the
            // empty set) is an invalid argument.
            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            /// Borrows the socket.
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for `$ty` rather than a fixed type, so that each type
        // the macro is invoked on receives its own impl, consistent with the
        // `AsSocket`/`AsFd` impls above.
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
314 
315 pub struct TcpStream(cap_std::net::TcpStream);
316 
317 impl TcpStream {
from_cap_std(socket: cap_std::net::TcpStream) -> Self318     pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
319         TcpStream(socket)
320     }
321 }
322 
323 wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
324 
325 #[cfg(unix)]
326 pub struct UnixStream(cap_std::os::unix::net::UnixStream);
327 
328 #[cfg(unix)]
329 impl UnixStream {
from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self330     pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
331         UnixStream(socket)
332     }
333 }
334 
335 #[cfg(unix)]
336 wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
337 
filetype_from(ft: &cap_std::fs::FileType) -> FileType338 pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
339     use cap_fs_ext::FileTypeExt;
340     if ft.is_block_device() {
341         FileType::SocketDgram
342     } else {
343         FileType::SocketStream
344     }
345 }
346 
347 /// Return the file-descriptor flags for a given file-like object.
348 ///
349 /// This returns the flags needed to implement [`WasiFile::get_fdflags`].
get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags>350 pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {
351     // On Unix-family platforms, we can use the same system call that we'd use
352     // for files on sockets here.
353     #[cfg(not(windows))]
354     {
355         let mut out = crate::file::FdFlags::empty();
356         if f.get_fd_flags()?
357             .contains(system_interface::fs::FdFlags::NONBLOCK)
358         {
359             out |= crate::file::FdFlags::NONBLOCK;
360         }
361         Ok(out)
362     }
363 
364     // On Windows, sockets are different, and there is no direct way to
365     // query for the non-blocking flag. We can get a sufficient approximation
366     // by testing whether a zero-length `recv` appears to block.
367     #[cfg(windows)]
368     let buf: &mut [u8] = &mut [];
369     #[cfg(windows)]
370     match rustix::net::recv(f, buf, rustix::net::RecvFlags::empty()) {
371         Ok(_) => Ok(crate::file::FdFlags::empty()),
372         Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),
373         Err(e) => Err(e.into()),
374     }
375 }
376 
377 /// Return the file-descriptor flags for a given file-like object.
378 ///
379 /// This returns the flags needed to implement [`WasiFile::get_fdflags`].
is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)>380 pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
381     // On Unix-family platforms, we have an `IsReadWrite` impl.
382     #[cfg(not(windows))]
383     {
384         f.is_read_write()
385     }
386 
387     // On Windows, we only have a `TcpStream` impl, so make a view first.
388     #[cfg(windows)]
389     {
390         f.as_socketlike_view::<std::net::TcpStream>()
391             .is_read_write()
392     }
393 }
394