1 use crate::wasi::clocks::monotonic_clock;
2 use crate::wasi::io::poll::{self, Pollable};
3 use crate::wasi::io::streams::{InputStream, OutputStream, StreamError};
4 use crate::wasi::random;
5 use crate::wasi::sockets::instance_network;
6 use crate::wasi::sockets::ip_name_lookup;
7 use crate::wasi::sockets::network::{
8     ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress,
9     Network,
10 };
11 use crate::wasi::sockets::tcp::TcpSocket;
12 use crate::wasi::sockets::udp::{
13     IncomingDatagram, IncomingDatagramStream, OutgoingDatagram, OutgoingDatagramStream, UdpSocket,
14 };
15 use crate::wasi::sockets::{tcp_create_socket, udp_create_socket};
16 use std::ops::Range;
17 
/// Deadline, in nanoseconds (1_000_000_000 ns = 1 second), passed to
/// `monotonic_clock::subscribe_duration` by the blocking helpers in this module.
const TIMEOUT_NS: u64 = 1_000_000_000;
19 
/// Reports whether IPv6 should be exercised.
///
/// Returns `false` only when the `DISABLE_IPV6` environment variable is set to a valid
/// UTF-8 value (any value counts). An unset variable, or a non-UTF-8 value, yields `true`.
pub fn supports_ipv6() -> bool {
    let disabled = std::env::var("DISABLE_IPV6").is_ok();
    !disabled
}
23 
24 impl Pollable {
block_until(&self, timeout: &Pollable) -> Result<(), ErrorCode>25     pub fn block_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> {
26         let ready = poll::poll(&[self, timeout]);
27         assert!(ready.len() > 0);
28         match ready[0] {
29             0 => Ok(()),
30             1 => Err(ErrorCode::Timeout),
31             _ => unreachable!(),
32         }
33     }
34 }
35 
36 impl InputStream {
blocking_read_to_end(&self) -> Result<Vec<u8>, crate::wasi::io::error::Error>37     pub fn blocking_read_to_end(&self) -> Result<Vec<u8>, crate::wasi::io::error::Error> {
38         let mut data = vec![];
39         loop {
40             match self.blocking_read(1024 * 1024) {
41                 Ok(chunk) => data.extend(chunk),
42                 Err(StreamError::Closed) => return Ok(data),
43                 Err(StreamError::LastOperationFailed(e)) => return Err(e),
44             }
45         }
46     }
47 }
48 
49 impl OutputStream {
blocking_write_util(&self, mut bytes: &[u8]) -> Result<(), StreamError>50     pub fn blocking_write_util(&self, mut bytes: &[u8]) -> Result<(), StreamError> {
51         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
52         let pollable = self.subscribe();
53 
54         while !bytes.is_empty() {
55             pollable.block_until(&timeout).expect("write timed out");
56 
57             let permit = self.check_write()?;
58 
59             let len = bytes.len().min(permit as usize);
60             let (chunk, rest) = bytes.split_at(len);
61 
62             self.write(chunk)?;
63 
64             self.blocking_flush()?;
65 
66             bytes = rest;
67         }
68         Ok(())
69     }
70 }
71 
72 impl Network {
default() -> Network73     pub fn default() -> Network {
74         instance_network::instance_network()
75     }
76 
blocking_resolve_addresses(&self, name: &str) -> Result<Vec<IpAddress>, ErrorCode>77     pub fn blocking_resolve_addresses(&self, name: &str) -> Result<Vec<IpAddress>, ErrorCode> {
78         let stream = ip_name_lookup::resolve_addresses(&self, name)?;
79 
80         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
81         let pollable = stream.subscribe();
82 
83         let mut addresses = vec![];
84 
85         loop {
86             match stream.resolve_next_address() {
87                 Ok(Some(addr)) => {
88                     addresses.push(addr);
89                 }
90                 Ok(None) => match addresses[..] {
91                     [] => return Err(ErrorCode::NameUnresolvable),
92                     _ => return Ok(addresses),
93                 },
94                 Err(ErrorCode::WouldBlock) => {
95                     pollable.block_until(&timeout)?;
96                 }
97                 Err(err) => return Err(err),
98             }
99         }
100     }
101 
102     /// Same as `Network::blocking_resolve_addresses` but ignores post validation errors
103     ///
104     /// The ignored error codes signal that the input passed validation
105     /// and a lookup was actually attempted, but failed. These are ignored to
106     /// make the CI tests less flaky.
permissive_blocking_resolve_addresses( &self, name: &str, ) -> Result<Vec<IpAddress>, ErrorCode>107     pub fn permissive_blocking_resolve_addresses(
108         &self,
109         name: &str,
110     ) -> Result<Vec<IpAddress>, ErrorCode> {
111         match self.blocking_resolve_addresses(name) {
112             Err(ErrorCode::NameUnresolvable | ErrorCode::TemporaryResolverFailure) => Ok(vec![]),
113             r => r,
114         }
115     }
116 }
117 
118 impl TcpSocket {
new(address_family: IpAddressFamily) -> Result<TcpSocket, ErrorCode>119     pub fn new(address_family: IpAddressFamily) -> Result<TcpSocket, ErrorCode> {
120         tcp_create_socket::create_tcp_socket(address_family)
121     }
122 
blocking_bind( &self, network: &Network, local_address: IpSocketAddress, ) -> Result<(), ErrorCode>123     pub fn blocking_bind(
124         &self,
125         network: &Network,
126         local_address: IpSocketAddress,
127     ) -> Result<(), ErrorCode> {
128         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
129         let sub = self.subscribe();
130 
131         self.start_bind(&network, local_address)?;
132 
133         loop {
134             match self.finish_bind() {
135                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
136                 result => return result,
137             }
138         }
139     }
140 
blocking_listen(&self) -> Result<(), ErrorCode>141     pub fn blocking_listen(&self) -> Result<(), ErrorCode> {
142         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
143         let sub = self.subscribe();
144 
145         self.start_listen()?;
146 
147         loop {
148             match self.finish_listen() {
149                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
150                 result => return result,
151             }
152         }
153     }
154 
blocking_connect( &self, network: &Network, remote_address: IpSocketAddress, ) -> Result<(InputStream, OutputStream), ErrorCode>155     pub fn blocking_connect(
156         &self,
157         network: &Network,
158         remote_address: IpSocketAddress,
159     ) -> Result<(InputStream, OutputStream), ErrorCode> {
160         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
161         let sub = self.subscribe();
162 
163         self.start_connect(&network, remote_address)?;
164 
165         loop {
166             match self.finish_connect() {
167                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
168                 result => return result,
169             }
170         }
171     }
172 
blocking_accept(&self) -> Result<(TcpSocket, InputStream, OutputStream), ErrorCode>173     pub fn blocking_accept(&self) -> Result<(TcpSocket, InputStream, OutputStream), ErrorCode> {
174         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
175         let sub = self.subscribe();
176 
177         loop {
178             match self.accept() {
179                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
180                 result => return result,
181             }
182         }
183     }
184 }
185 
186 impl UdpSocket {
new(address_family: IpAddressFamily) -> Result<UdpSocket, ErrorCode>187     pub fn new(address_family: IpAddressFamily) -> Result<UdpSocket, ErrorCode> {
188         udp_create_socket::create_udp_socket(address_family)
189     }
190 
blocking_bind( &self, network: &Network, local_address: IpSocketAddress, ) -> Result<(), ErrorCode>191     pub fn blocking_bind(
192         &self,
193         network: &Network,
194         local_address: IpSocketAddress,
195     ) -> Result<(), ErrorCode> {
196         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
197         let sub = self.subscribe();
198 
199         self.start_bind(&network, local_address)?;
200 
201         loop {
202             match self.finish_bind() {
203                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
204                 result => return result,
205             }
206         }
207     }
208 
blocking_bind_unspecified(&self, network: &Network) -> Result<(), ErrorCode>209     pub fn blocking_bind_unspecified(&self, network: &Network) -> Result<(), ErrorCode> {
210         let ip = IpAddress::new_unspecified(self.address_family());
211         let port = 0;
212 
213         self.blocking_bind(network, IpSocketAddress::new(ip, port))
214     }
215 }
216 
217 impl OutgoingDatagramStream {
blocking_check_send(&self, timeout: &Pollable) -> Result<u64, ErrorCode>218     fn blocking_check_send(&self, timeout: &Pollable) -> Result<u64, ErrorCode> {
219         let sub = self.subscribe();
220 
221         loop {
222             match self.check_send() {
223                 Ok(0) => sub.block_until(timeout)?,
224                 result => return result,
225             }
226         }
227     }
228 
blocking_send(&self, mut datagrams: &[OutgoingDatagram]) -> Result<(), ErrorCode>229     pub fn blocking_send(&self, mut datagrams: &[OutgoingDatagram]) -> Result<(), ErrorCode> {
230         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
231 
232         while !datagrams.is_empty() {
233             let permit = self.blocking_check_send(&timeout)?;
234             let chunk_len = datagrams.len().min(permit as usize);
235             match self.send(&datagrams[..chunk_len]) {
236                 Ok(0) => {}
237                 Ok(packets_sent) => {
238                     let packets_sent = packets_sent as usize;
239                     datagrams = &datagrams[packets_sent..];
240                 }
241                 Err(err) => return Err(err),
242             }
243         }
244 
245         Ok(())
246     }
247 }
248 
249 impl IncomingDatagramStream {
blocking_receive(&self, count: Range<u64>) -> Result<Vec<IncomingDatagram>, ErrorCode>250     pub fn blocking_receive(&self, count: Range<u64>) -> Result<Vec<IncomingDatagram>, ErrorCode> {
251         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
252         let pollable = self.subscribe();
253         let mut datagrams = vec![];
254 
255         loop {
256             match self.receive(count.end - datagrams.len() as u64) {
257                 Ok(mut chunk) => {
258                     datagrams.append(&mut chunk);
259 
260                     if datagrams.len() >= count.start as usize {
261                         return Ok(datagrams);
262                     } else {
263                         pollable.block_until(&timeout)?;
264                     }
265                 }
266                 Err(err) => return Err(err),
267             }
268         }
269     }
270 }
271 
272 impl IpAddress {
273     pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255));
274 
275     pub const IPV4_LOOPBACK: IpAddress = IpAddress::Ipv4((127, 0, 0, 1));
276     pub const IPV6_LOOPBACK: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 1));
277 
278     pub const IPV4_UNSPECIFIED: IpAddress = IpAddress::Ipv4((0, 0, 0, 0));
279     pub const IPV6_UNSPECIFIED: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 0));
280 
281     pub const IPV4_MAPPED_LOOPBACK: IpAddress =
282         IpAddress::Ipv6((0, 0, 0, 0, 0, 0xFFFF, 0x7F00, 0x0001));
283 
new_loopback(family: IpAddressFamily) -> IpAddress284     pub const fn new_loopback(family: IpAddressFamily) -> IpAddress {
285         match family {
286             IpAddressFamily::Ipv4 => Self::IPV4_LOOPBACK,
287             IpAddressFamily::Ipv6 => Self::IPV6_LOOPBACK,
288         }
289     }
290 
new_unspecified(family: IpAddressFamily) -> IpAddress291     pub const fn new_unspecified(family: IpAddressFamily) -> IpAddress {
292         match family {
293             IpAddressFamily::Ipv4 => Self::IPV4_UNSPECIFIED,
294             IpAddressFamily::Ipv6 => Self::IPV6_UNSPECIFIED,
295         }
296     }
297 
family(&self) -> IpAddressFamily298     pub const fn family(&self) -> IpAddressFamily {
299         match self {
300             IpAddress::Ipv4(_) => IpAddressFamily::Ipv4,
301             IpAddress::Ipv6(_) => IpAddressFamily::Ipv6,
302         }
303     }
304 }
305 
306 impl PartialEq for IpAddress {
eq(&self, other: &Self) -> bool307     fn eq(&self, other: &Self) -> bool {
308         match (self, other) {
309             (Self::Ipv4(left), Self::Ipv4(right)) => left == right,
310             (Self::Ipv6(left), Self::Ipv6(right)) => left == right,
311             _ => false,
312         }
313     }
314 }
315 
316 impl IpSocketAddress {
new(ip: IpAddress, port: u16) -> IpSocketAddress317     pub const fn new(ip: IpAddress, port: u16) -> IpSocketAddress {
318         match ip {
319             IpAddress::Ipv4(addr) => IpSocketAddress::Ipv4(Ipv4SocketAddress {
320                 port,
321                 address: addr,
322             }),
323             IpAddress::Ipv6(addr) => IpSocketAddress::Ipv6(Ipv6SocketAddress {
324                 port,
325                 address: addr,
326                 flow_info: 0,
327                 scope_id: 0,
328             }),
329         }
330     }
331 
ip(&self) -> IpAddress332     pub const fn ip(&self) -> IpAddress {
333         match self {
334             IpSocketAddress::Ipv4(addr) => IpAddress::Ipv4(addr.address),
335             IpSocketAddress::Ipv6(addr) => IpAddress::Ipv6(addr.address),
336         }
337     }
338 
port(&self) -> u16339     pub const fn port(&self) -> u16 {
340         match self {
341             IpSocketAddress::Ipv4(addr) => addr.port,
342             IpSocketAddress::Ipv6(addr) => addr.port,
343         }
344     }
345 
family(&self) -> IpAddressFamily346     pub const fn family(&self) -> IpAddressFamily {
347         match self {
348             IpSocketAddress::Ipv4(_) => IpAddressFamily::Ipv4,
349             IpSocketAddress::Ipv6(_) => IpAddressFamily::Ipv6,
350         }
351     }
352 }
353 
354 impl PartialEq for Ipv4SocketAddress {
eq(&self, other: &Self) -> bool355     fn eq(&self, other: &Self) -> bool {
356         self.port == other.port && self.address == other.address
357     }
358 }
359 
360 impl PartialEq for Ipv6SocketAddress {
eq(&self, other: &Self) -> bool361     fn eq(&self, other: &Self) -> bool {
362         self.port == other.port
363             && self.flow_info == other.flow_info
364             && self.address == other.address
365             && self.scope_id == other.scope_id
366     }
367 }
368 
369 impl PartialEq for IpSocketAddress {
eq(&self, other: &Self) -> bool370     fn eq(&self, other: &Self) -> bool {
371         match (self, other) {
372             (Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0,
373             (Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0,
374             _ => false,
375         }
376     }
377 }
378 
generate_random_u16(range: Range<u16>) -> u16379 fn generate_random_u16(range: Range<u16>) -> u16 {
380     let start = range.start as u64;
381     let end = range.end as u64;
382     let port = start + (random::random::get_random_u64() % (end - start));
383     port as u16
384 }
385 
386 /// Execute the inner function with a randomly generated port.
387 /// To prevent random failures, we make a few attempts before giving up.
attempt_random_port<F>( local_address: IpAddress, mut f: F, ) -> Result<IpSocketAddress, ErrorCode> where F: FnMut(IpSocketAddress) -> Result<(), ErrorCode>,388 pub fn attempt_random_port<F>(
389     local_address: IpAddress,
390     mut f: F,
391 ) -> Result<IpSocketAddress, ErrorCode>
392 where
393     F: FnMut(IpSocketAddress) -> Result<(), ErrorCode>,
394 {
395     const MAX_ATTEMPTS: u32 = 10;
396     let mut i = 0;
397     loop {
398         i += 1;
399 
400         let port: u16 = generate_random_u16(1024..u16::MAX);
401         let sock_addr = IpSocketAddress::new(local_address, port);
402 
403         match f(sock_addr) {
404             Ok(_) => return Ok(sock_addr),
405             Err(e) if i >= MAX_ATTEMPTS => return Err(e),
406             // Try again if the port is already taken. This can sometimes show up as `AccessDenied` on Windows.
407             Err(ErrorCode::AddressInUse | ErrorCode::AccessDenied) => {}
408             Err(e) => return Err(e),
409         }
410     }
411 }
412