1 use futures::join;
2 use std::pin::pin;
3 use std::task::{Context, Poll, Waker};
4 use test_programs::p3::wasi::sockets::types::{
5     ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket,
6 };
7 use test_programs::p3::wit_stream;
8 use test_programs::sockets::supports_ipv6;
9 use wit_bindgen::{FutureReader, StreamReader, StreamResult, StreamWriter};
10 
/// The component type. The `wasi:cli/run` `Guest` implementation that drives
/// all of the tests below is provided for it further down in this file.
struct Component;

// Export `Component` through the `test_programs::p3::export!` macro.
// NOTE(review): presumably this generates the component-model export glue for
// the `Guest` impl — confirm against the macro definition.
test_programs::p3::export!(Component);
14 
15 /// Test basic functionality.
test_tcp_ping_pong(family: IpAddressFamily)16 async fn test_tcp_ping_pong(family: IpAddressFamily) {
17     setup(family, |mut server, mut client| async move {
18         {
19             let rest = server.send_stream.write_all(b"ping".into()).await;
20             assert!(rest.is_empty());
21         }
22         {
23             let (status, buf) = client.receive_stream.read(Vec::with_capacity(4)).await;
24             assert_eq!(status, StreamResult::Complete(4));
25             assert_eq!(buf, b"ping");
26         }
27         {
28             let rest = client.send_stream.write_all(b"pong".into()).await;
29             assert!(rest.is_empty());
30         }
31         {
32             let (status, buf) = server.receive_stream.read(Vec::with_capacity(4)).await;
33             assert_eq!(status, StreamResult::Complete(4));
34             assert_eq!(buf, b"pong");
35         }
36     })
37     .await;
38 }
39 
40 /// The stream and future returned by `receive` should complete/resolve after
41 /// the connection has been shut down by the remote.
test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(family: IpAddressFamily)42 async fn test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(family: IpAddressFamily) {
43     setup(family, |server, mut client| async move {
44         drop(server);
45 
46         // Wait for the shutdown signal to reach the client:
47         let (stream_result, data) = client.receive_stream.read(Vec::with_capacity(1)).await;
48         assert_eq!(data.len(), 0);
49         assert_eq!(stream_result, StreamResult::Dropped);
50         client.receive_result.await.unwrap();
51     })
52     .await;
53 }
54 
55 /// The future returned by `receive` should resolve once the companion stream
56 /// has been dropped. Regardless of whether there was still data pending.
test_tcp_receive_future_should_resolve_when_stream_dropped(family: IpAddressFamily)57 async fn test_tcp_receive_future_should_resolve_when_stream_dropped(family: IpAddressFamily) {
58     setup(family, |mut server, client| async move {
59         {
60             let rest = server.send_stream.write_all(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await;
61             assert!(rest.is_empty());
62         }
63         {
64             let Connection { mut receive_stream, receive_result, .. } = client;
65 
66             // Wait for the data to be ready:
67             receive_stream.next().await.unwrap();
68             drop(receive_stream);
69 
70             // Dropping the stream should've caused the future to resolve even
71             // though there was still data pending:
72             receive_result.await.unwrap();
73         }
74     }).await;
75 }
76 
77 /// The future returned by `send` should resolve after the input stream is dropped.
test_tcp_send_future_should_resolve_when_stream_dropped(family: IpAddressFamily)78 async fn test_tcp_send_future_should_resolve_when_stream_dropped(family: IpAddressFamily) {
79     setup(family, |_server, client| async move {
80         let Connection {
81             send_stream,
82             send_result,
83             ..
84         } = client;
85         drop(send_stream);
86         send_result.await.unwrap();
87     })
88     .await;
89 }
90 
91 /// `send` should drop the input stream when the connection is shut down by the remote.
test_tcp_send_drops_stream_when_remote_shutdown(family: IpAddressFamily)92 async fn test_tcp_send_drops_stream_when_remote_shutdown(family: IpAddressFamily) {
93     setup(family, |server, mut client| async move {
94         drop(server);
95 
96         // Give it a few tries for the shutdown signal to reach the client:
97         loop {
98             let stream_result = client.send_stream.write(b"undeliverable".into()).await.0;
99             if stream_result == StreamResult::Dropped {
100                 break;
101             }
102         }
103 
104         let result = client.send_result.await;
105         assert!(
106             matches!(
107                 result,
108                 Err(ErrorCode::ConnectionBroken | ErrorCode::ConnectionReset)
109             ),
110             "unexpected error {result:?}",
111         );
112     })
113     .await;
114 }
115 
116 /// `receive` may be called successfully at most once.
test_tcp_receive_once(family: IpAddressFamily)117 async fn test_tcp_receive_once(family: IpAddressFamily) {
118     setup(family, |mut server, client| async move {
119         // Give the client some potential data to _hopefully never_ read.
120         {
121             let rest = server.send_stream.write_all(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await;
122             assert!(rest.is_empty());
123         }
124 
125         // FYI, the first call to `receive` is part of the `setup` code, so every
126         // `receive` in here should fail.
127         for _ in 0..3 {
128             let (mut reader, future) = client.socket.receive();
129 
130             let (stream_result, data) = reader.read(Vec::with_capacity(10)).await;
131             assert_eq!(data.len(), 0);
132             assert_eq!(stream_result, StreamResult::Dropped);
133             assert!(matches!(future.await, Err(ErrorCode::InvalidState)));
134         }
135     })
136     .await;
137 }
138 
139 /// `send` may be called successfully at most once.
test_tcp_send_once(family: IpAddressFamily)140 async fn test_tcp_send_once(family: IpAddressFamily) {
141     setup(family, |_server, client| async move {
142         // FYI, the first call to `send` is part of the `setup` code, so every
143         // `send` in here should fail.
144         for _ in 0..3 {
145             let (mut writer, send_rx) = wit_stream::new();
146             let future = client.socket.send(send_rx);
147 
148             const DATA: &[u8] = b"undeliverable";
149             let (stream_result, rest) = writer.write(DATA.into()).await;
150             assert_eq!(rest.into_vec(), DATA);
151             assert_eq!(stream_result, StreamResult::Dropped);
152             assert!(matches!(future.await, Err(ErrorCode::InvalidState)));
153         }
154     })
155     .await;
156 }
157 
158 /// The streams and futures returned by `send` and `receive` should remain
159 /// operational even after the socket that spawned them has been dropped.
test_tcp_stream_lifetimes(family: IpAddressFamily)160 async fn test_tcp_stream_lifetimes(family: IpAddressFamily) {
161     setup(family, |server, client| async move {
162         let Connection {
163             socket: server_socket,
164             send_stream: mut server_send_stream,
165             receive_stream: server_receive_stream,
166             send_result: server_send_result,
167             receive_result: server_receive_result,
168         } = server;
169         let Connection {
170             socket: client_socket,
171             send_stream: mut client_send_stream,
172             receive_stream: client_receive_stream,
173             send_result: client_send_result,
174             receive_result: client_receive_result,
175         } = client;
176 
177         // Drop the parent sockets:
178         drop(server_socket);
179         drop(client_socket);
180 
181         {
182             let rest = server_send_stream.write_all(b"ping".into()).await;
183             assert!(rest.is_empty());
184             drop(server_send_stream);
185             server_send_result.await.unwrap();
186         }
187         {
188             let data = client_receive_stream.collect().await;
189             assert_eq!(data, b"ping");
190             client_receive_result.await.unwrap();
191         }
192         {
193             let rest = client_send_stream.write_all(b"pong".into()).await;
194             assert!(rest.is_empty());
195             drop(client_send_stream);
196             client_send_result.await.unwrap();
197         }
198         {
199             let data = server_receive_stream.collect().await;
200             assert_eq!(data, b"pong");
201             server_receive_result.await.unwrap();
202         }
203     })
204     .await;
205 }
206 
207 /// Model a situation where there's a continuous stream of data coming into the
208 /// guest from one side and the other side is reading in chunks but also
209 /// cancelling reads occasionally. Should receive the complete stream of data
210 /// into the result.
test_tcp_read_cancellation(family: IpAddressFamily)211 async fn test_tcp_read_cancellation(family: IpAddressFamily) {
212     // Send 2M of data in 256-byte chunks.
213     const CHUNKS: usize = (2 << 20) / 256;
214     let mut data = [0; 256];
215     for (i, slot) in data.iter_mut().enumerate() {
216         *slot = i as u8;
217     }
218 
219     setup(family, |mut server, mut client| async move {
220         // Minimize the local send buffer:
221         client.socket.set_send_buffer_size(1024).unwrap();
222 
223         join!(
224             async {
225                 for _ in 0..CHUNKS {
226                     let ret = client.send_stream.write_all(data.to_vec()).await;
227                     assert!(ret.is_empty());
228                 }
229                 drop(client.send_stream);
230             },
231             async {
232                 let mut buf = Vec::with_capacity(1024);
233                 let mut i = 0_usize;
234                 let mut consecutive_zero_length_reads = 0;
235                 loop {
236                     assert!(buf.is_empty());
237                     let (status, b) = {
238                         let mut fut = pin!(server.receive_stream.read(buf));
239                         let mut cx = Context::from_waker(Waker::noop());
240                         match fut.as_mut().poll(&mut cx) {
241                             Poll::Ready(pair) => pair,
242                             Poll::Pending => fut.cancel(),
243                         }
244                     };
245                     buf = b;
246                     match status {
247                         StreamResult::Complete(n) => {
248                             assert_eq!(buf.len(), n);
249                             for slot in buf.iter_mut() {
250                                 assert_eq!(*slot, i as u8);
251                                 i = i.wrapping_add(1);
252                             }
253                             buf.truncate(0);
254                             consecutive_zero_length_reads = 0;
255                         }
256                         StreamResult::Dropped => break,
257                         StreamResult::Cancelled => {
258                             assert!(consecutive_zero_length_reads < 10);
259                             consecutive_zero_length_reads += 1;
260                             server.receive_stream.read(Vec::new()).await;
261                         }
262                     }
263                 }
264                 assert_eq!(i, CHUNKS * 256);
265                 server.receive_result.await.unwrap();
266             },
267         );
268     })
269     .await;
270 }
271 
272 impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>273     async fn run() -> Result<(), ()> {
274         test_tcp_ping_pong(IpAddressFamily::Ipv4).await;
275         test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(IpAddressFamily::Ipv4).await;
276         test_tcp_receive_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv4).await;
277         test_tcp_send_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv4).await;
278         test_tcp_send_drops_stream_when_remote_shutdown(IpAddressFamily::Ipv4).await;
279         test_tcp_receive_once(IpAddressFamily::Ipv4).await;
280         test_tcp_send_once(IpAddressFamily::Ipv4).await;
281         test_tcp_stream_lifetimes(IpAddressFamily::Ipv4).await;
282         test_tcp_read_cancellation(IpAddressFamily::Ipv4).await;
283 
284         if supports_ipv6() {
285             test_tcp_ping_pong(IpAddressFamily::Ipv6).await;
286             test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(IpAddressFamily::Ipv6)
287                 .await;
288             test_tcp_receive_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv6).await;
289             test_tcp_send_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv6).await;
290             test_tcp_send_drops_stream_when_remote_shutdown(IpAddressFamily::Ipv6).await;
291             test_tcp_receive_once(IpAddressFamily::Ipv6).await;
292             test_tcp_send_once(IpAddressFamily::Ipv6).await;
293             test_tcp_stream_lifetimes(IpAddressFamily::Ipv6).await;
294             test_tcp_read_cancellation(IpAddressFamily::Ipv6).await;
295         }
296         Ok(())
297     }
298 }
299 
// Intentionally empty: the tests are driven by `Guest::run` above, which this
// component exports. NOTE(review): presumably `main` exists only so the crate
// builds as a binary target — confirm.
fn main() {}
301 
302 struct Connection {
303     socket: TcpSocket,
304     receive_stream: StreamReader<u8>,
305     receive_result: FutureReader<Result<(), ErrorCode>>,
306     send_stream: StreamWriter<u8>,
307     send_result: FutureReader<Result<(), ErrorCode>>,
308 }
309 impl Connection {
new(socket: TcpSocket) -> Self310     fn new(socket: TcpSocket) -> Self {
311         let (send_stream, send_rx) = wit_stream::new();
312         let send_result = socket.send(send_rx);
313         let (receive_stream, receive_result) = socket.receive();
314         Self {
315             socket,
316             receive_stream,
317             receive_result,
318             send_stream,
319             send_result,
320         }
321     }
322 }
323 
324 /// Set up a connected pair of sockets
setup<Fut: Future<Output = ()>>( family: IpAddressFamily, body: impl FnOnce(Connection, Connection) -> Fut, )325 async fn setup<Fut: Future<Output = ()>>(
326     family: IpAddressFamily,
327     body: impl FnOnce(Connection, Connection) -> Fut,
328 ) {
329     let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
330     let listener = TcpSocket::create(family).unwrap();
331     listener.bind(bind_address).unwrap();
332     let mut accept = listener.listen().unwrap();
333     let bound_address = listener.get_local_address().unwrap();
334     let client_socket = TcpSocket::create(family).unwrap();
335     let ((), accepted_socket) = join!(
336         async {
337             client_socket.connect(bound_address).await.unwrap();
338         },
339         async { accept.next().await.unwrap() },
340     );
341 
342     body(
343         Connection::new(accepted_socket),
344         Connection::new(client_socket),
345     )
346     .await;
347 }
348