1 use test_programs::sockets::supports_ipv6;
2 use test_programs::wasi::io::streams::{InputStream, OutputStream, StreamError};
3 use test_programs::wasi::sockets::network::{IpAddress, IpAddressFamily, IpSocketAddress, Network};
4 use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
5 
6 /// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server.
test_tcp_input_stream_should_be_closed_by_remote_shutdown( net: &Network, family: IpAddressFamily, )7 fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(
8     net: &Network,
9     family: IpAddressFamily,
10 ) {
11     setup(net, family, |server, client| {
12         // Shut down the connection from the server side:
13         server.socket.shutdown(ShutdownType::Both).unwrap();
14         drop(server);
15 
16         // Wait for the shutdown signal to reach the client:
17         client.input.subscribe().block();
18 
19         // The input stream should immediately signal StreamError::Closed.
20         // Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK)
21         // See: https://github.com/bytecodealliance/wasmtime/pull/8968
22         assert!(matches!(client.input.read(10), Err(StreamError::Closed)));
23 
24         // Stream should still be closed, even when requesting 0 bytes:
25         assert!(matches!(client.input.read(0), Err(StreamError::Closed)));
26     });
27 }
28 
29 /// InputStream::read should return `StreamError::Closed` after the connection has been shut down locally.
test_tcp_input_stream_should_be_closed_by_local_shutdown( net: &Network, family: IpAddressFamily, )30 fn test_tcp_input_stream_should_be_closed_by_local_shutdown(
31     net: &Network,
32     family: IpAddressFamily,
33 ) {
34     setup(net, family, |server, client| {
35         // On Linux, `recv` continues to work even after `shutdown(sock, SHUT_RD)`
36         // has been called. To properly test that this behavior doesn't happen in
37         // WASI, we make sure there's some data to read by the client:
38         server.output.blocking_write_util(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.").unwrap();
39 
40         // Wait for the data to reach the client:
41         client.input.subscribe().block();
42 
43         // Shut down socket locally:
44         client.socket.shutdown(ShutdownType::Receive).unwrap();
45 
46         // The input stream should immediately signal StreamError::Closed.
47         assert!(matches!(client.input.read(10), Err(StreamError::Closed)));
48 
49         // Stream should still be closed, even when requesting 0 bytes:
50         assert!(matches!(client.input.read(0), Err(StreamError::Closed)));
51     });
52 }
53 
54 /// OutputStream should return `StreamError::Closed` after the connection has been locally shut down for sending.
test_tcp_output_stream_should_be_closed_by_local_shutdown( net: &Network, family: IpAddressFamily, )55 fn test_tcp_output_stream_should_be_closed_by_local_shutdown(
56     net: &Network,
57     family: IpAddressFamily,
58 ) {
59     setup(net, family, |_server, client| {
60         let message = b"Hi!";
61 
62         // The stream should be writable:
63         assert!(client.output.check_write().unwrap() as usize >= message.len());
64 
65         // Perform the shutdown
66         client.socket.shutdown(ShutdownType::Send).unwrap();
67 
68         // Stream should be closed:
69         assert!(matches!(
70             client.output.write(message),
71             Err(StreamError::Closed)
72         ));
73 
74         // The stream should remain closed:
75         assert!(matches!(
76             client.output.check_write(),
77             Err(StreamError::Closed)
78         ));
79         assert!(matches!(client.output.flush(), Err(StreamError::Closed)));
80     });
81 }
82 
83 /// Calling `shutdown` while the OutputStream is in the middle of a background write should not cause that write to be lost.
test_tcp_shutdown_should_not_lose_data(net: &Network, family: IpAddressFamily)84 fn test_tcp_shutdown_should_not_lose_data(net: &Network, family: IpAddressFamily) {
85     setup(net, family, |server, client| {
86         // Minimize the local send buffer:
87         client.socket.set_send_buffer_size(1024).unwrap();
88         let small_buffer_size = client.socket.send_buffer_size().unwrap();
89 
90         // Create a significantly bigger buffer, so that we can be pretty sure the `write` won't finish immediately:
91         let big_buffer_size = client
92             .output
93             .check_write()
94             .unwrap()
95             .min(100 * small_buffer_size);
96         assert!(big_buffer_size > small_buffer_size);
97         let outgoing_data = vec![0; big_buffer_size as usize];
98 
99         // Submit the oversized buffer and immediately initiate the shutdown:
100         client.output.write(&outgoing_data).unwrap();
101         client.socket.shutdown(ShutdownType::Send).unwrap();
102 
103         // The peer should receive _all_ data:
104         let incoming_data = server.input.blocking_read_to_end().unwrap();
105         assert_eq!(
106             outgoing_data.len(),
107             incoming_data.len(),
108             "Received data should match the sent data"
109         );
110     });
111 }
112 
main()113 fn main() {
114     let net = Network::default();
115 
116     test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv4);
117     test_tcp_input_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv4);
118     test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv4);
119     test_tcp_shutdown_should_not_lose_data(&net, IpAddressFamily::Ipv4);
120 
121     if supports_ipv6() {
122         test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv6);
123         test_tcp_input_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv6);
124         test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv6);
125         test_tcp_shutdown_should_not_lose_data(&net, IpAddressFamily::Ipv6);
126     }
127 }
128 
/// One endpoint of a connected TCP pair created by `setup`: the socket together with the
/// input/output streams obtained for it.
///
/// Field order matters: Rust drops struct fields in declaration order, so both streams are
/// dropped before `socket`. NOTE(review): presumably this is relied on because WASI treats the
/// streams as child resources of the socket — confirm before reordering the fields.
struct Connection {
    /// Stream for reading data received from the peer.
    input: InputStream,
    /// Stream for sending data to the peer.
    output: OutputStream,
    /// The underlying socket, used for socket-level operations such as `shutdown`.
    socket: TcpSocket,
}
134 
135 /// Set up a connected pair of sockets
setup(net: &Network, family: IpAddressFamily, body: impl FnOnce(Connection, Connection))136 fn setup(net: &Network, family: IpAddressFamily, body: impl FnOnce(Connection, Connection)) {
137     let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
138     let listener = TcpSocket::new(family).unwrap();
139     listener.blocking_bind(&net, bind_address).unwrap();
140     listener.blocking_listen().unwrap();
141     let bound_address = listener.local_address().unwrap();
142     let client_socket = TcpSocket::new(family).unwrap();
143     let (client_input, client_output) = client_socket.blocking_connect(net, bound_address).unwrap();
144     let (accepted_socket, accepted_input, accepted_output) = listener.blocking_accept().unwrap();
145 
146     body(
147         Connection {
148             input: accepted_input,
149             output: accepted_output,
150             socket: accepted_socket,
151         },
152         Connection {
153             input: client_input,
154             output: client_output,
155             socket: client_socket,
156         },
157     );
158 }
159