1 use test_programs::{
2     p3::wasi::sockets::types::{
3         ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4Address, Ipv6Address, UdpSocket,
4     },
5     sockets::supports_ipv6,
6 };
7 
/// Component type whose `Guest::run` implementation (below) drives every UDP
/// connect test in this program.
struct Component;

// Presumably generates the component export glue for `Component` — see the
// `test_programs::p3` crate for details.
test_programs::p3::export!(Component);

// Probe targets. Within this file these are only ever passed to `connect`,
// never to `send`, so if the tests work as expected no traffic reaches them.
const SOME_PORT: u16 = 47;
// Public-looking IPv4 address used to check for a non-loopback route.
const SOME_PUBLIC_IPV4: Ipv4Address = (123, 234, 12, 34);
// IPv6 counterpart of `SOME_PUBLIC_IPV4`.
const SOME_PUBLIC_IPV6: Ipv6Address = (123, 234, 0, 0, 0, 0, 0, 34);
16 
test_udp_connect_disconnect_reconnect(family: IpAddressFamily)17 fn test_udp_connect_disconnect_reconnect(family: IpAddressFamily) {
18     let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321);
19     let remote2 = IpSocketAddress::new(IpAddress::new_loopback(family), 4320);
20 
21     let client = UdpSocket::create(family).unwrap();
22 
23     assert!(matches!(client.disconnect(), Err(ErrorCode::InvalidState)));
24     assert!(matches!(
25         client.get_remote_address(),
26         Err(ErrorCode::InvalidState)
27     ));
28 
29     assert!(matches!(client.disconnect(), Err(ErrorCode::InvalidState)));
30     assert!(matches!(
31         client.get_remote_address(),
32         Err(ErrorCode::InvalidState)
33     ));
34 
35     _ = client.connect(remote1).unwrap();
36     assert_eq!(client.get_remote_address().unwrap(), remote1);
37 
38     _ = client.connect(remote1).unwrap();
39     assert_eq!(client.get_remote_address().unwrap(), remote1);
40 
41     _ = client.connect(remote2).unwrap();
42     assert_eq!(client.get_remote_address().unwrap(), remote2);
43 
44     _ = client.disconnect().unwrap();
45     assert!(matches!(
46         client.get_remote_address(),
47         Err(ErrorCode::InvalidState)
48     ));
49 
50     _ = client.connect(remote1).unwrap();
51     assert_eq!(client.get_remote_address().unwrap(), remote1);
52 }
53 
54 /// `0.0.0.0` / `::` is not a valid remote address in WASI.
test_udp_connect_unspec(family: IpAddressFamily)55 fn test_udp_connect_unspec(family: IpAddressFamily) {
56     let ip = IpAddress::new_unspecified(family);
57     let addr = IpSocketAddress::new(ip, SOME_PORT);
58     let sock = UdpSocket::create(family).unwrap();
59 
60     assert!(matches!(
61         sock.connect(addr),
62         Err(ErrorCode::InvalidArgument)
63     ));
64 }
65 
66 /// If not explicitly bound, connecting a UDP socket should update the local
67 /// address to reflect the best network path.
test_udp_connect_local_address_change(family: IpAddressFamily)68 fn test_udp_connect_local_address_change(family: IpAddressFamily) {
69     fn connect(sock: &UdpSocket, ip: IpAddress, port: u16) -> IpSocketAddress {
70         let remote = IpSocketAddress::new(ip, port);
71         sock.connect(remote).unwrap();
72         let local = sock.get_local_address().unwrap();
73         println!("connect({remote:?}) changed local address to: {local:?}",);
74         local
75     }
76 
77     if !has_public_interface(family) {
78         println!("No public interface detected, skipping test");
79         return;
80     }
81 
82     let loopback_ip = IpAddress::new_loopback(family);
83     let public_ip = some_public_ip(family);
84 
85     let client = UdpSocket::create(family).unwrap();
86 
87     let loopback_if1 = connect(&client, loopback_ip, 4321);
88     let loopback_if2 = connect(&client, loopback_ip, 4322);
89     let public_if = connect(&client, public_ip, 4323);
90 
91     // Note: these assertions are based on observed behavior on Linux, MacOS and
92     // Windows, but there is nothing in their official documentation to
93     // corroborate this.
94     assert_eq!(loopback_if1, loopback_if2);
95     assert_ne!(loopback_if1, public_if);
96 }
97 
98 /// 0 is not a valid remote port.
test_udp_connect_port_0(family: IpAddressFamily)99 fn test_udp_connect_port_0(family: IpAddressFamily) {
100     let addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
101     let sock = UdpSocket::create(family).unwrap();
102 
103     assert!(matches!(
104         sock.connect(addr),
105         Err(ErrorCode::InvalidArgument)
106     ));
107 }
108 
109 /// Connect should validate the address family.
test_udp_connect_wrong_family(family: IpAddressFamily)110 fn test_udp_connect_wrong_family(family: IpAddressFamily) {
111     let wrong_ip = match family {
112         IpAddressFamily::Ipv4 => IpAddress::IPV6_LOOPBACK,
113         IpAddressFamily::Ipv6 => IpAddress::IPV4_LOOPBACK,
114     };
115     let remote_addr = IpSocketAddress::new(wrong_ip, SOME_PORT);
116 
117     let sock = UdpSocket::create(family).unwrap();
118 
119     assert!(matches!(
120         sock.connect(remote_addr),
121         Err(ErrorCode::InvalidArgument)
122     ));
123 }
124 
125 /// Connect should perform implicit bind.
test_udp_connect_without_bind(family: IpAddressFamily)126 fn test_udp_connect_without_bind(family: IpAddressFamily) {
127     let remote_addr = IpSocketAddress::new(IpAddress::new_loopback(family), SOME_PORT);
128 
129     let sock = UdpSocket::create(family).unwrap();
130 
131     assert!(matches!(sock.get_local_address(), Err(_)));
132     assert!(matches!(sock.connect(remote_addr), Ok(_)));
133     assert!(matches!(sock.get_local_address(), Ok(_)));
134 }
135 
136 /// Connect should work in combination with an explicit bind.
test_udp_connect_with_bind(family: IpAddressFamily)137 fn test_udp_connect_with_bind(family: IpAddressFamily) {
138     let remote_addr = IpSocketAddress::new(IpAddress::new_loopback(family), SOME_PORT);
139 
140     let sock = UdpSocket::create(family).unwrap();
141 
142     sock.bind_unspecified().unwrap();
143 
144     assert!(matches!(sock.get_local_address(), Ok(_)));
145     assert!(matches!(sock.connect(remote_addr), Ok(_)));
146     assert!(matches!(sock.get_local_address(), Ok(_)));
147 }
148 
test_udp_connect_dual_stack()149 fn test_udp_connect_dual_stack() {
150     // Set-up:
151     let v4_server = UdpSocket::create(IpAddressFamily::Ipv4).unwrap();
152     v4_server
153         .bind(IpSocketAddress::new(IpAddress::IPV4_LOOPBACK, 0))
154         .unwrap();
155 
156     let v4_server_addr = v4_server.get_local_address().unwrap();
157     let v6_server_addr =
158         IpSocketAddress::new(IpAddress::IPV4_MAPPED_LOOPBACK, v4_server_addr.port());
159 
160     // Tests:
161     let v6_client = UdpSocket::create(IpAddressFamily::Ipv6).unwrap();
162 
163     v6_client.bind_unspecified().unwrap();
164 
165     // Connecting to an IPv4 address on an IPv6 socket should fail:
166     assert!(matches!(
167         v6_client.connect(v4_server_addr),
168         Err(ErrorCode::InvalidArgument)
169     ));
170 
171     // Connecting to an IPv4-mapped-IPv6 address on an IPv6 socket should fail:
172     assert!(matches!(
173         v6_client.connect(v6_server_addr),
174         Err(ErrorCode::InvalidArgument)
175     ));
176 }
177 
178 /// A UDP socket should be immediately writable
test_udp_connect_and_send(family: IpAddressFamily)179 async fn test_udp_connect_and_send(family: IpAddressFamily) {
180     let unspecified_port = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
181     let remote = IpSocketAddress::new(IpAddress::new_loopback(family), 4320);
182 
183     let client = UdpSocket::create(family).unwrap();
184     client.bind(unspecified_port).unwrap();
185 
186     client.connect(remote).unwrap();
187     assert_eq!(client.get_remote_address().unwrap(), remote);
188 
189     client.send(b"hello".into(), None).await.unwrap();
190 }
191 
192 impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>193     async fn run() -> Result<(), ()> {
194         let supports_ipv6 = supports_ipv6();
195 
196         test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv4);
197         test_udp_connect_unspec(IpAddressFamily::Ipv4);
198         test_udp_connect_local_address_change(IpAddressFamily::Ipv4);
199         test_udp_connect_port_0(IpAddressFamily::Ipv4);
200         test_udp_connect_wrong_family(IpAddressFamily::Ipv4);
201         test_udp_connect_without_bind(IpAddressFamily::Ipv4);
202         test_udp_connect_with_bind(IpAddressFamily::Ipv4);
203         test_udp_connect_and_send(IpAddressFamily::Ipv4).await;
204 
205         if supports_ipv6 {
206             test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv6);
207             test_udp_connect_unspec(IpAddressFamily::Ipv6);
208             test_udp_connect_local_address_change(IpAddressFamily::Ipv6);
209             test_udp_connect_port_0(IpAddressFamily::Ipv6);
210             test_udp_connect_wrong_family(IpAddressFamily::Ipv6);
211             test_udp_connect_without_bind(IpAddressFamily::Ipv6);
212             test_udp_connect_with_bind(IpAddressFamily::Ipv6);
213             test_udp_connect_and_send(IpAddressFamily::Ipv6).await;
214             test_udp_connect_dual_stack();
215         }
216 
217         Ok(())
218     }
219 }
220 
some_public_ip(family: IpAddressFamily) -> IpAddress221 fn some_public_ip(family: IpAddressFamily) -> IpAddress {
222     match family {
223         IpAddressFamily::Ipv4 => IpAddress::Ipv4(SOME_PUBLIC_IPV4),
224         IpAddressFamily::Ipv6 => IpAddress::Ipv6(SOME_PUBLIC_IPV6),
225     }
226 }
227 
has_public_interface(family: IpAddressFamily) -> bool228 fn has_public_interface(family: IpAddressFamily) -> bool {
229     let sock = UdpSocket::create(family).unwrap();
230     sock.connect(IpSocketAddress::new(some_public_ip(family), SOME_PORT))
231         .is_ok()
232 }
233 
// The tests are driven by the exported `Guest::run` implementation above. This
// empty `main` is presumably only required so the file builds as a binary target.
fn main() {}
235