1 use futures::join;
2 use test_programs::p3::wasi::sockets::types::{
3     ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket,
4 };
5 use test_programs::sockets::supports_ipv6;
6 
/// The component type; its `Guest::run` implementation (further down in this
/// file) drives every test in this program.
struct Component;

// Hands `Component` to the bindings' export macro so its `Guest` impl is
// exported from the component (macro expansion not visible here).
test_programs::p3::export!(Component);

/// Remote port used only by the negative tests, whose `connect` calls are
/// expected to be rejected before any network traffic happens.
const SOME_PORT: u16 = 47; // If the tests pass, this will never actually be connected to.
12 
13 /// `0.0.0.0` / `::` is not a valid remote address in WASI.
test_tcp_connect_unspec(family: IpAddressFamily)14 async fn test_tcp_connect_unspec(family: IpAddressFamily) {
15     let addr = IpSocketAddress::new(IpAddress::new_unspecified(family), SOME_PORT);
16     let sock = TcpSocket::create(family).unwrap();
17 
18     assert!(matches!(
19         sock.connect(addr).await,
20         Err(ErrorCode::InvalidArgument)
21     ));
22 }
23 
24 /// 0 is not a valid remote port.
test_tcp_connect_port_0(family: IpAddressFamily)25 async fn test_tcp_connect_port_0(family: IpAddressFamily) {
26     let addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
27     let sock = TcpSocket::create(family).unwrap();
28 
29     assert!(matches!(
30         sock.connect(addr).await,
31         Err(ErrorCode::InvalidArgument)
32     ));
33 }
34 
35 /// Connect should validate the address family.
test_tcp_connect_wrong_family(family: IpAddressFamily)36 async fn test_tcp_connect_wrong_family(family: IpAddressFamily) {
37     let wrong_ip = match family {
38         IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK,
39         IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK,
40     };
41     let remote_addr = IpSocketAddress::new(wrong_ip, SOME_PORT);
42 
43     let sock = TcpSocket::create(family).unwrap();
44 
45     assert!(matches!(
46         sock.connect(remote_addr).await,
47         Err(ErrorCode::InvalidArgument)
48     ));
49 }
50 
51 /// Can only connect to unicast addresses.
test_tcp_connect_non_unicast()52 async fn test_tcp_connect_non_unicast() {
53     let ipv4_broadcast = IpSocketAddress::new(IpAddress::IPV4_BROADCAST, SOME_PORT);
54     let ipv4_multicast = IpSocketAddress::new(IpAddress::Ipv4((224, 254, 0, 0)), SOME_PORT);
55     let ipv6_multicast =
56         IpSocketAddress::new(IpAddress::Ipv6((0xff00, 0, 0, 0, 0, 0, 0, 0)), SOME_PORT);
57 
58     let sock_v4 = TcpSocket::create(IpAddressFamily::Ipv4).unwrap();
59     let sock_v6 = TcpSocket::create(IpAddressFamily::Ipv6).unwrap();
60 
61     assert!(matches!(
62         sock_v4.connect(ipv4_broadcast).await,
63         Err(ErrorCode::InvalidArgument)
64     ));
65     assert!(matches!(
66         sock_v4.connect(ipv4_multicast).await,
67         Err(ErrorCode::InvalidArgument)
68     ));
69     assert!(matches!(
70         sock_v6.connect(ipv6_multicast).await,
71         Err(ErrorCode::InvalidArgument)
72     ));
73 }
74 
test_tcp_connect_dual_stack()75 async fn test_tcp_connect_dual_stack() {
76     // Set-up:
77     let v4_listener = TcpSocket::create(IpAddressFamily::Ipv4).unwrap();
78     v4_listener
79         .bind(IpSocketAddress::new(IpAddress::IPV4_LOOPBACK, 0))
80         .unwrap();
81     v4_listener.listen().unwrap();
82 
83     let v4_listener_addr = v4_listener.get_local_address().unwrap();
84     let v6_listener_addr =
85         IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, v4_listener_addr.port());
86 
87     let v6_client = TcpSocket::create(IpAddressFamily::Ipv6).unwrap();
88 
89     // Tests:
90 
91     // Connecting to an IPv4 address on an IPv6 socket should fail:
92     assert!(matches!(
93         v6_client.connect(v4_listener_addr).await,
94         Err(ErrorCode::InvalidArgument)
95     ));
96     // Connecting to an IPv4-mapped-IPv6 address on an IPv6 socket should fail:
97     assert!(matches!(
98         v6_client.connect(v6_listener_addr).await,
99         Err(ErrorCode::InvalidArgument)
100     ));
101 }
102 
103 /// Client sockets can be explicitly bound.
test_tcp_connect_explicit_bind(family: IpAddressFamily)104 async fn test_tcp_connect_explicit_bind(family: IpAddressFamily) {
105     let ip = IpAddress::new_loopback(family);
106 
107     let (listener, mut accept) = {
108         let bind_address = IpSocketAddress::new(ip, 0);
109         let listener = TcpSocket::create(family).unwrap();
110         listener.bind(bind_address).unwrap();
111         let accept = listener.listen().unwrap();
112         (listener, accept)
113     };
114 
115     let listener_address = listener.get_local_address().unwrap();
116 
117     // Connect should work:
118     join!(
119         async {
120             let client = TcpSocket::create(family).unwrap();
121             client
122                 .bind(IpSocketAddress::new(IpAddress::new_unspecified(family), 0))
123                 .unwrap();
124             println!("local address: {:?}", client.get_local_address().unwrap());
125             client.connect(listener_address).await.unwrap();
126             println!("local address: {:?}", client.get_local_address().unwrap());
127         },
128         async {
129             accept.next().await.unwrap();
130         }
131     );
132 }
133 
134 /// Connecting a TCP socket should update the local address to reflect the best
135 /// network path.
test_tcp_connect_local_address_change(family: IpAddressFamily)136 async fn test_tcp_connect_local_address_change(family: IpAddressFamily) {
137     let ip_unspec = IpAddress::new_unspecified(family);
138     let ip_loopback = IpAddress::new_loopback(family);
139 
140     let (listener, mut accept) = {
141         let bind_address = IpSocketAddress::new(ip_loopback, 0);
142         let listener = TcpSocket::create(family).unwrap();
143         listener.bind(bind_address).unwrap();
144         let accept = listener.listen().unwrap();
145         (listener, accept)
146     };
147 
148     join!(
149         async {
150             let listener_address = listener.get_local_address().unwrap();
151             let client = TcpSocket::create(family).unwrap();
152             client.bind(IpSocketAddress::new(ip_unspec, 0)).unwrap();
153 
154             let before = client.get_local_address().unwrap();
155             client.connect(listener_address).await.unwrap();
156             let after = client.get_local_address().unwrap();
157 
158             println!("local address changed from {before:?} to {after:?}");
159 
160             // Note: these assertions are based on observed behavior on Linux,
161             // MacOS and Windows, but there is nothing in their official
162             // documentation to corroborate this.
163             assert_eq!(before.ip(), ip_unspec);
164             assert_eq!(after.ip(), ip_loopback);
165             assert_eq!(before.port(), after.port());
166         },
167         async {
168             accept.next().await.unwrap();
169         }
170     );
171 }
172 
173 impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>174     async fn run() -> Result<(), ()> {
175         test_tcp_connect_unspec(IpAddressFamily::Ipv4).await;
176         test_tcp_connect_port_0(IpAddressFamily::Ipv4).await;
177         test_tcp_connect_wrong_family(IpAddressFamily::Ipv4).await;
178         test_tcp_connect_explicit_bind(IpAddressFamily::Ipv4).await;
179         test_tcp_connect_local_address_change(IpAddressFamily::Ipv4).await;
180 
181         if supports_ipv6() {
182             test_tcp_connect_unspec(IpAddressFamily::Ipv6).await;
183             test_tcp_connect_port_0(IpAddressFamily::Ipv6).await;
184             test_tcp_connect_wrong_family(IpAddressFamily::Ipv6).await;
185             test_tcp_connect_non_unicast().await;
186             test_tcp_connect_dual_stack().await;
187             test_tcp_connect_explicit_bind(IpAddressFamily::Ipv6).await;
188             test_tcp_connect_local_address_change(IpAddressFamily::Ipv6).await;
189         }
190         Ok(())
191     }
192 }
193 
// Intentionally empty: the tests are driven by the exported `Guest::run`
// implementation. (Presumably `main` exists only so this builds as a binary
// target — confirm against the build setup.)
fn main() {}
195