1 use crate::wasi::clocks::monotonic_clock;
2 use crate::wasi::io::poll::{self, Pollable};
3 use crate::wasi::io::streams::{InputStream, OutputStream, StreamError};
4 use crate::wasi::random;
5 use crate::wasi::sockets::instance_network;
6 use crate::wasi::sockets::ip_name_lookup;
7 use crate::wasi::sockets::network::{
8     ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress,
9     Network,
10 };
11 use crate::wasi::sockets::tcp::TcpSocket;
12 use crate::wasi::sockets::udp::{
13     IncomingDatagram, IncomingDatagramStream, OutgoingDatagram, OutgoingDatagramStream, UdpSocket,
14 };
15 use crate::wasi::sockets::{tcp_create_socket, udp_create_socket};
16 use std::ops::Range;
17 
/// Deadline in nanoseconds (1_000_000_000 ns = 1 s) passed to
/// `monotonic_clock::subscribe_duration` by the blocking helpers in this file.
const TIMEOUT_NS: u64 = 1_000_000_000;
19 
20 impl Pollable {
21     pub fn block_until(&self, timeout: &Pollable) -> Result<(), ErrorCode> {
22         let ready = poll::poll(&[self, timeout]);
23         assert!(ready.len() > 0);
24         match ready[0] {
25             0 => Ok(()),
26             1 => Err(ErrorCode::Timeout),
27             _ => unreachable!(),
28         }
29     }
30 }
31 
32 impl InputStream {
33     pub fn blocking_read_to_end(&self) -> Result<Vec<u8>, crate::wasi::io::error::Error> {
34         let mut data = vec![];
35         loop {
36             match self.blocking_read(1024 * 1024) {
37                 Ok(chunk) => data.extend(chunk),
38                 Err(StreamError::Closed) => return Ok(data),
39                 Err(StreamError::LastOperationFailed(e)) => return Err(e),
40             }
41         }
42     }
43 }
44 
45 impl OutputStream {
46     pub fn blocking_write_util(&self, mut bytes: &[u8]) -> Result<(), StreamError> {
47         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
48         let pollable = self.subscribe();
49 
50         while !bytes.is_empty() {
51             pollable.block_until(&timeout).expect("write timed out");
52 
53             let permit = self.check_write()?;
54 
55             let len = bytes.len().min(permit as usize);
56             let (chunk, rest) = bytes.split_at(len);
57 
58             self.write(chunk)?;
59 
60             self.blocking_flush()?;
61 
62             bytes = rest;
63         }
64         Ok(())
65     }
66 }
67 
68 impl Network {
69     pub fn default() -> Network {
70         instance_network::instance_network()
71     }
72 
73     pub fn blocking_resolve_addresses(&self, name: &str) -> Result<Vec<IpAddress>, ErrorCode> {
74         let stream = ip_name_lookup::resolve_addresses(&self, name)?;
75 
76         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
77         let pollable = stream.subscribe();
78 
79         let mut addresses = vec![];
80 
81         loop {
82             match stream.resolve_next_address() {
83                 Ok(Some(addr)) => {
84                     addresses.push(addr);
85                 }
86                 Ok(None) => match addresses[..] {
87                     [] => return Err(ErrorCode::NameUnresolvable),
88                     _ => return Ok(addresses),
89                 },
90                 Err(ErrorCode::WouldBlock) => {
91                     pollable.block_until(&timeout)?;
92                 }
93                 Err(err) => return Err(err),
94             }
95         }
96     }
97 
98     /// Same as `Network::blocking_resolve_addresses` but ignores post validation errors
99     ///
100     /// The ignored error codes signal that the input passed validation
101     /// and a lookup was actually attempted, but failed. These are ignored to
102     /// make the CI tests less flaky.
103     pub fn permissive_blocking_resolve_addresses(
104         &self,
105         name: &str,
106     ) -> Result<Vec<IpAddress>, ErrorCode> {
107         match self.blocking_resolve_addresses(name) {
108             Err(ErrorCode::NameUnresolvable | ErrorCode::TemporaryResolverFailure) => Ok(vec![]),
109             r => r,
110         }
111     }
112 }
113 
114 impl TcpSocket {
115     pub fn new(address_family: IpAddressFamily) -> Result<TcpSocket, ErrorCode> {
116         tcp_create_socket::create_tcp_socket(address_family)
117     }
118 
119     pub fn blocking_bind(
120         &self,
121         network: &Network,
122         local_address: IpSocketAddress,
123     ) -> Result<(), ErrorCode> {
124         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
125         let sub = self.subscribe();
126 
127         self.start_bind(&network, local_address)?;
128 
129         loop {
130             match self.finish_bind() {
131                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
132                 result => return result,
133             }
134         }
135     }
136 
137     pub fn blocking_listen(&self) -> Result<(), ErrorCode> {
138         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
139         let sub = self.subscribe();
140 
141         self.start_listen()?;
142 
143         loop {
144             match self.finish_listen() {
145                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
146                 result => return result,
147             }
148         }
149     }
150 
151     pub fn blocking_connect(
152         &self,
153         network: &Network,
154         remote_address: IpSocketAddress,
155     ) -> Result<(InputStream, OutputStream), ErrorCode> {
156         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
157         let sub = self.subscribe();
158 
159         self.start_connect(&network, remote_address)?;
160 
161         loop {
162             match self.finish_connect() {
163                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
164                 result => return result,
165             }
166         }
167     }
168 
169     pub fn blocking_accept(&self) -> Result<(TcpSocket, InputStream, OutputStream), ErrorCode> {
170         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
171         let sub = self.subscribe();
172 
173         loop {
174             match self.accept() {
175                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
176                 result => return result,
177             }
178         }
179     }
180 }
181 
182 impl UdpSocket {
183     pub fn new(address_family: IpAddressFamily) -> Result<UdpSocket, ErrorCode> {
184         udp_create_socket::create_udp_socket(address_family)
185     }
186 
187     pub fn blocking_bind(
188         &self,
189         network: &Network,
190         local_address: IpSocketAddress,
191     ) -> Result<(), ErrorCode> {
192         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
193         let sub = self.subscribe();
194 
195         self.start_bind(&network, local_address)?;
196 
197         loop {
198             match self.finish_bind() {
199                 Err(ErrorCode::WouldBlock) => sub.block_until(&timeout)?,
200                 result => return result,
201             }
202         }
203     }
204 
205     pub fn blocking_bind_unspecified(&self, network: &Network) -> Result<(), ErrorCode> {
206         let ip = IpAddress::new_unspecified(self.address_family());
207         let port = 0;
208 
209         self.blocking_bind(network, IpSocketAddress::new(ip, port))
210     }
211 }
212 
213 impl OutgoingDatagramStream {
214     fn blocking_check_send(&self, timeout: &Pollable) -> Result<u64, ErrorCode> {
215         let sub = self.subscribe();
216 
217         loop {
218             match self.check_send() {
219                 Ok(0) => sub.block_until(timeout)?,
220                 result => return result,
221             }
222         }
223     }
224 
225     pub fn blocking_send(&self, mut datagrams: &[OutgoingDatagram]) -> Result<(), ErrorCode> {
226         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
227 
228         while !datagrams.is_empty() {
229             let permit = self.blocking_check_send(&timeout)?;
230             let chunk_len = datagrams.len().min(permit as usize);
231             match self.send(&datagrams[..chunk_len]) {
232                 Ok(0) => {}
233                 Ok(packets_sent) => {
234                     let packets_sent = packets_sent as usize;
235                     datagrams = &datagrams[packets_sent..];
236                 }
237                 Err(err) => return Err(err),
238             }
239         }
240 
241         Ok(())
242     }
243 }
244 
245 impl IncomingDatagramStream {
246     pub fn blocking_receive(&self, count: Range<u64>) -> Result<Vec<IncomingDatagram>, ErrorCode> {
247         let timeout = monotonic_clock::subscribe_duration(TIMEOUT_NS);
248         let pollable = self.subscribe();
249         let mut datagrams = vec![];
250 
251         loop {
252             match self.receive(count.end - datagrams.len() as u64) {
253                 Ok(mut chunk) => {
254                     datagrams.append(&mut chunk);
255 
256                     if datagrams.len() >= count.start as usize {
257                         return Ok(datagrams);
258                     } else {
259                         pollable.block_until(&timeout)?;
260                     }
261                 }
262                 Err(err) => return Err(err),
263             }
264         }
265     }
266 }
267 
268 impl IpAddress {
269     pub const IPV4_BROADCAST: IpAddress = IpAddress::Ipv4((255, 255, 255, 255));
270 
271     pub const IPV4_LOOPBACK: IpAddress = IpAddress::Ipv4((127, 0, 0, 1));
272     pub const IPV6_LOOPBACK: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 1));
273 
274     pub const IPV4_UNSPECIFIED: IpAddress = IpAddress::Ipv4((0, 0, 0, 0));
275     pub const IPV6_UNSPECIFIED: IpAddress = IpAddress::Ipv6((0, 0, 0, 0, 0, 0, 0, 0));
276 
277     pub const IPV4_MAPPED_LOOPBACK: IpAddress =
278         IpAddress::Ipv6((0, 0, 0, 0, 0, 0xFFFF, 0x7F00, 0x0001));
279 
280     pub const fn new_loopback(family: IpAddressFamily) -> IpAddress {
281         match family {
282             IpAddressFamily::Ipv4 => Self::IPV4_LOOPBACK,
283             IpAddressFamily::Ipv6 => Self::IPV6_LOOPBACK,
284         }
285     }
286 
287     pub const fn new_unspecified(family: IpAddressFamily) -> IpAddress {
288         match family {
289             IpAddressFamily::Ipv4 => Self::IPV4_UNSPECIFIED,
290             IpAddressFamily::Ipv6 => Self::IPV6_UNSPECIFIED,
291         }
292     }
293 
294     pub const fn family(&self) -> IpAddressFamily {
295         match self {
296             IpAddress::Ipv4(_) => IpAddressFamily::Ipv4,
297             IpAddress::Ipv6(_) => IpAddressFamily::Ipv6,
298         }
299     }
300 }
301 
302 impl PartialEq for IpAddress {
303     fn eq(&self, other: &Self) -> bool {
304         match (self, other) {
305             (Self::Ipv4(left), Self::Ipv4(right)) => left == right,
306             (Self::Ipv6(left), Self::Ipv6(right)) => left == right,
307             _ => false,
308         }
309     }
310 }
311 
312 impl IpSocketAddress {
313     pub const fn new(ip: IpAddress, port: u16) -> IpSocketAddress {
314         match ip {
315             IpAddress::Ipv4(addr) => IpSocketAddress::Ipv4(Ipv4SocketAddress {
316                 port,
317                 address: addr,
318             }),
319             IpAddress::Ipv6(addr) => IpSocketAddress::Ipv6(Ipv6SocketAddress {
320                 port,
321                 address: addr,
322                 flow_info: 0,
323                 scope_id: 0,
324             }),
325         }
326     }
327 
328     pub const fn ip(&self) -> IpAddress {
329         match self {
330             IpSocketAddress::Ipv4(addr) => IpAddress::Ipv4(addr.address),
331             IpSocketAddress::Ipv6(addr) => IpAddress::Ipv6(addr.address),
332         }
333     }
334 
335     pub const fn port(&self) -> u16 {
336         match self {
337             IpSocketAddress::Ipv4(addr) => addr.port,
338             IpSocketAddress::Ipv6(addr) => addr.port,
339         }
340     }
341 
342     pub const fn family(&self) -> IpAddressFamily {
343         match self {
344             IpSocketAddress::Ipv4(_) => IpAddressFamily::Ipv4,
345             IpSocketAddress::Ipv6(_) => IpAddressFamily::Ipv6,
346         }
347     }
348 }
349 
350 impl PartialEq for Ipv4SocketAddress {
351     fn eq(&self, other: &Self) -> bool {
352         self.port == other.port && self.address == other.address
353     }
354 }
355 
356 impl PartialEq for Ipv6SocketAddress {
357     fn eq(&self, other: &Self) -> bool {
358         self.port == other.port
359             && self.flow_info == other.flow_info
360             && self.address == other.address
361             && self.scope_id == other.scope_id
362     }
363 }
364 
365 impl PartialEq for IpSocketAddress {
366     fn eq(&self, other: &Self) -> bool {
367         match (self, other) {
368             (Self::Ipv4(l0), Self::Ipv4(r0)) => l0 == r0,
369             (Self::Ipv6(l0), Self::Ipv6(r0)) => l0 == r0,
370             _ => false,
371         }
372     }
373 }
374 
375 fn generate_random_u16(range: Range<u16>) -> u16 {
376     let start = range.start as u64;
377     let end = range.end as u64;
378     let port = start + (random::random::get_random_u64() % (end - start));
379     port as u16
380 }
381 
382 /// Execute the inner function with a randomly generated port.
383 /// To prevent random failures, we make a few attempts before giving up.
384 pub fn attempt_random_port<F>(
385     local_address: IpAddress,
386     mut f: F,
387 ) -> Result<IpSocketAddress, ErrorCode>
388 where
389     F: FnMut(IpSocketAddress) -> Result<(), ErrorCode>,
390 {
391     const MAX_ATTEMPTS: u32 = 10;
392     let mut i = 0;
393     loop {
394         i += 1;
395 
396         let port: u16 = generate_random_u16(1024..u16::MAX);
397         let sock_addr = IpSocketAddress::new(local_address, port);
398 
399         match f(sock_addr) {
400             Ok(_) => return Ok(sock_addr),
401             Err(e) if i >= MAX_ATTEMPTS => return Err(e),
402             // Try again if the port is already taken. This can sometimes show up as `AccessDenied` on Windows.
403             Err(ErrorCode::AddressInUse | ErrorCode::AccessDenied) => {}
404             Err(e) => return Err(e),
405         }
406     }
407 }
408