1 use futures::join;
2 use test_programs::p3::sockets::attempt_random_port;
3 use test_programs::p3::wasi::sockets::types::{
4     ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket,
5 };
6 use test_programs::p3::wit_stream;
7 use test_programs::sockets::supports_ipv6;
8 
/// Guest type for this test component.
///
/// It implements the `wasi:cli/run` `Guest` trait (see the `impl` further down);
/// the `export!` invocation presumably generates the component-model export glue
/// for it — the macro is defined in `test_programs`, confirm there.
struct Component;

test_programs::p3::export!(Component);
12 
13 /// Bind a socket and let the system determine a port.
test_tcp_bind_ephemeral_port(ip: IpAddress)14 fn test_tcp_bind_ephemeral_port(ip: IpAddress) {
15     let bind_addr = IpSocketAddress::new(ip, 0);
16 
17     let sock = TcpSocket::create(ip.family()).unwrap();
18     sock.bind(bind_addr).unwrap();
19 
20     let bound_addr = sock.get_local_address().unwrap();
21 
22     assert_eq!(bind_addr.ip(), bound_addr.ip());
23     assert_ne!(bind_addr.port(), bound_addr.port());
24 }
25 
26 /// Bind a socket on a specified port.
test_tcp_bind_specific_port(ip: IpAddress)27 fn test_tcp_bind_specific_port(ip: IpAddress) {
28     let sock = TcpSocket::create(ip.family()).unwrap();
29 
30     let bind_addr = attempt_random_port(ip, |bind_addr| sock.bind(bind_addr)).unwrap();
31 
32     let bound_addr = sock.get_local_address().unwrap();
33 
34     assert_eq!(bind_addr.ip(), bound_addr.ip());
35     assert_eq!(bind_addr.port(), bound_addr.port());
36 }
37 
38 /// Two sockets may not be actively bound to the same address at the same time.
test_tcp_bind_addrinuse(ip: IpAddress)39 fn test_tcp_bind_addrinuse(ip: IpAddress) {
40     let bind_addr = IpSocketAddress::new(ip, 0);
41 
42     let sock1 = TcpSocket::create(ip.family()).unwrap();
43     sock1.bind(bind_addr).unwrap();
44     sock1.listen().unwrap();
45 
46     let bound_addr = sock1.get_local_address().unwrap();
47 
48     let sock2 = TcpSocket::create(ip.family()).unwrap();
49     assert!(matches!(
50         sock2.bind(bound_addr),
51         Err(ErrorCode::AddressInUse)
52     ));
53 }
54 
55 // The WASI runtime should set SO_REUSEADDR for us
test_tcp_bind_reuseaddr(ip: IpAddress)56 async fn test_tcp_bind_reuseaddr(ip: IpAddress) {
57     let client = TcpSocket::create(ip.family()).unwrap();
58 
59     let bind_addr = {
60         let listener1 = TcpSocket::create(ip.family()).unwrap();
61 
62         listener1
63             .bind(IpSocketAddress::new(
64                 IpAddress::new_loopback(ip.family()),
65                 0,
66             ))
67             .unwrap();
68 
69         let bind_addr = listener1.get_local_address().unwrap();
70 
71         // The listener socket must have at least one connection for the TIME_WAIT
72         // mechanism to kick in. So we'll create & accept a dummy connection
73         // before closing the listener:
74         {
75             let mut accept = listener1.listen().unwrap();
76 
77             let connect_addr =
78                 IpSocketAddress::new(IpAddress::new_loopback(ip.family()), bind_addr.port());
79             join!(
80                 async {
81                     client.connect(connect_addr).await.unwrap();
82                 },
83                 async {
84                     let sock = accept.next().await.unwrap();
85                     let (mut data_tx, data_rx) = wit_stream::new();
86                     join!(
87                         async {
88                             sock.send(data_rx).await.unwrap();
89                         },
90                         async {
91                             let remaining = data_tx.write_all(vec![0; 10]).await;
92                             assert!(remaining.is_empty());
93                             drop(data_tx);
94                         }
95                     );
96                 },
97             );
98         }
99 
100         bind_addr
101     };
102 
103     // If SO_REUSEADDR was configured correctly, the following lines
104     // shouldn't be affected by the TIME_WAIT state of the just closed
105     // `listener1` socket:
106     let listener2 = TcpSocket::create(ip.family()).unwrap();
107     listener2.bind(bind_addr).unwrap();
108     listener2.listen().unwrap();
109 }
110 
111 // Try binding to an address that is not configured on the system.
test_tcp_bind_addrnotavail(ip: IpAddress)112 fn test_tcp_bind_addrnotavail(ip: IpAddress) {
113     let bind_addr = IpSocketAddress::new(ip, 0);
114 
115     let sock = TcpSocket::create(ip.family()).unwrap();
116 
117     assert!(matches!(
118         sock.bind(bind_addr),
119         Err(ErrorCode::AddressNotBindable)
120     ));
121 }
122 
123 /// Bind should validate the address family.
test_tcp_bind_wrong_family(family: IpAddressFamily)124 fn test_tcp_bind_wrong_family(family: IpAddressFamily) {
125     let wrong_ip = match family {
126         IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK,
127         IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK,
128     };
129 
130     let sock = TcpSocket::create(family).unwrap();
131     let result = sock.bind(IpSocketAddress::new(wrong_ip, 0));
132 
133     assert!(matches!(result, Err(ErrorCode::InvalidArgument)));
134 }
135 
136 /// Bind only works on unicast addresses.
test_tcp_bind_non_unicast()137 fn test_tcp_bind_non_unicast() {
138     let ipv4_broadcast = IpSocketAddress::new(IpAddress::IPV4_BROADCAST, 0);
139     let ipv4_multicast = IpSocketAddress::new(IpAddress::Ipv4((224, 254, 0, 0)), 0);
140     let ipv6_multicast = IpSocketAddress::new(IpAddress::Ipv6((0xff00, 0, 0, 0, 0, 0, 0, 0)), 0);
141 
142     let sock_v4 = TcpSocket::create(IpAddressFamily::Ipv4).unwrap();
143     let sock_v6 = TcpSocket::create(IpAddressFamily::Ipv6).unwrap();
144 
145     assert!(matches!(
146         sock_v4.bind(ipv4_broadcast),
147         Err(ErrorCode::InvalidArgument)
148     ));
149     assert!(matches!(
150         sock_v4.bind(ipv4_multicast),
151         Err(ErrorCode::InvalidArgument)
152     ));
153     assert!(matches!(
154         sock_v6.bind(ipv6_multicast),
155         Err(ErrorCode::InvalidArgument)
156     ));
157 }
158 
test_tcp_bind_dual_stack()159 fn test_tcp_bind_dual_stack() {
160     let sock = TcpSocket::create(IpAddressFamily::Ipv6).unwrap();
161     let addr = IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, 0);
162 
163     // Binding an IPv4-mapped-IPv6 address on a ipv6-only socket should fail:
164     assert!(matches!(sock.bind(addr), Err(ErrorCode::InvalidArgument)));
165 }
166 
167 impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>168     async fn run() -> Result<(), ()> {
169         const RESERVED_IPV4_ADDRESS: IpAddress = IpAddress::Ipv4((192, 0, 2, 0)); // Reserved for documentation and examples.
170         const RESERVED_IPV6_ADDRESS: IpAddress =
171             IpAddress::Ipv6((0x2001, 0x0db8, 0, 0, 0, 0, 0, 0)); // Reserved for documentation and examples.
172 
173         test_tcp_bind_ephemeral_port(IpAddress::IPV4_LOOPBACK);
174         test_tcp_bind_ephemeral_port(IpAddress::IPV4_UNSPECIFIED);
175         test_tcp_bind_specific_port(IpAddress::IPV4_LOOPBACK);
176         test_tcp_bind_specific_port(IpAddress::IPV4_UNSPECIFIED);
177         test_tcp_bind_reuseaddr(IpAddress::IPV4_LOOPBACK).await;
178         test_tcp_bind_addrinuse(IpAddress::IPV4_LOOPBACK);
179         test_tcp_bind_addrinuse(IpAddress::IPV4_UNSPECIFIED);
180         test_tcp_bind_addrnotavail(RESERVED_IPV4_ADDRESS);
181         test_tcp_bind_wrong_family(IpAddressFamily::Ipv4);
182 
183         if supports_ipv6() {
184             test_tcp_bind_ephemeral_port(IpAddress::IPV6_LOOPBACK);
185             test_tcp_bind_ephemeral_port(IpAddress::IPV6_UNSPECIFIED);
186             test_tcp_bind_specific_port(IpAddress::IPV6_LOOPBACK);
187             test_tcp_bind_specific_port(IpAddress::IPV6_UNSPECIFIED);
188             test_tcp_bind_reuseaddr(IpAddress::IPV6_LOOPBACK).await;
189             test_tcp_bind_addrinuse(IpAddress::IPV6_LOOPBACK);
190             test_tcp_bind_addrinuse(IpAddress::IPV6_UNSPECIFIED);
191             test_tcp_bind_addrnotavail(RESERVED_IPV6_ADDRESS);
192             test_tcp_bind_wrong_family(IpAddressFamily::Ipv6);
193             test_tcp_bind_non_unicast();
194             test_tcp_bind_dual_stack();
195         }
196 
197         Ok(())
198     }
199 }
200 
main()201 fn main() {}
202