1 use futures::join;
2 use std::pin::pin;
3 use std::task::{Context, Poll, Waker};
4 use test_programs::p3::wasi::sockets::types::{
5 ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, TcpSocket,
6 };
7 use test_programs::p3::wit_stream;
8 use test_programs::sockets::supports_ipv6;
9 use wit_bindgen::{FutureReader, StreamReader, StreamResult, StreamWriter};
10
/// Unit type on which this file implements the
/// `test_programs::p3::exports::wasi::cli::run::Guest` trait.
struct Component;

// Hand `Component` to the `test_programs::p3` export macro.
test_programs::p3::export!(Component);
14
15 /// Test basic functionality.
test_tcp_ping_pong(family: IpAddressFamily)16 async fn test_tcp_ping_pong(family: IpAddressFamily) {
17 setup(family, |mut server, mut client| async move {
18 {
19 let rest = server.send_stream.write_all(b"ping".into()).await;
20 assert!(rest.is_empty());
21 }
22 {
23 let (status, buf) = client.receive_stream.read(Vec::with_capacity(4)).await;
24 assert_eq!(status, StreamResult::Complete(4));
25 assert_eq!(buf, b"ping");
26 }
27 {
28 let rest = client.send_stream.write_all(b"pong".into()).await;
29 assert!(rest.is_empty());
30 }
31 {
32 let (status, buf) = server.receive_stream.read(Vec::with_capacity(4)).await;
33 assert_eq!(status, StreamResult::Complete(4));
34 assert_eq!(buf, b"pong");
35 }
36 })
37 .await;
38 }
39
40 /// The stream and future returned by `receive` should complete/resolve after
41 /// the connection has been shut down by the remote.
test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(family: IpAddressFamily)42 async fn test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(family: IpAddressFamily) {
43 setup(family, |server, mut client| async move {
44 drop(server);
45
46 // Wait for the shutdown signal to reach the client:
47 let (stream_result, data) = client.receive_stream.read(Vec::with_capacity(1)).await;
48 assert_eq!(data.len(), 0);
49 assert_eq!(stream_result, StreamResult::Dropped);
50 client.receive_result.await.unwrap();
51 })
52 .await;
53 }
54
55 /// The future returned by `receive` should resolve once the companion stream
56 /// has been dropped. Regardless of whether there was still data pending.
test_tcp_receive_future_should_resolve_when_stream_dropped(family: IpAddressFamily)57 async fn test_tcp_receive_future_should_resolve_when_stream_dropped(family: IpAddressFamily) {
58 setup(family, |mut server, client| async move {
59 {
60 let rest = server.send_stream.write_all(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await;
61 assert!(rest.is_empty());
62 }
63 {
64 let Connection { mut receive_stream, receive_result, .. } = client;
65
66 // Wait for the data to be ready:
67 receive_stream.next().await.unwrap();
68 drop(receive_stream);
69
70 // Dropping the stream should've caused the future to resolve even
71 // though there was still data pending:
72 receive_result.await.unwrap();
73 }
74 }).await;
75 }
76
77 /// The future returned by `send` should resolve after the input stream is dropped.
test_tcp_send_future_should_resolve_when_stream_dropped(family: IpAddressFamily)78 async fn test_tcp_send_future_should_resolve_when_stream_dropped(family: IpAddressFamily) {
79 setup(family, |_server, client| async move {
80 let Connection {
81 send_stream,
82 send_result,
83 ..
84 } = client;
85 drop(send_stream);
86 send_result.await.unwrap();
87 })
88 .await;
89 }
90
91 /// `send` should drop the input stream when the connection is shut down by the remote.
test_tcp_send_drops_stream_when_remote_shutdown(family: IpAddressFamily)92 async fn test_tcp_send_drops_stream_when_remote_shutdown(family: IpAddressFamily) {
93 setup(family, |server, mut client| async move {
94 drop(server);
95
96 // Give it a few tries for the shutdown signal to reach the client:
97 loop {
98 let stream_result = client.send_stream.write(b"undeliverable".into()).await.0;
99 if stream_result == StreamResult::Dropped {
100 break;
101 }
102 }
103
104 let result = client.send_result.await;
105 assert!(
106 matches!(
107 result,
108 Err(ErrorCode::ConnectionBroken | ErrorCode::ConnectionReset)
109 ),
110 "unexpected error {result:?}",
111 );
112 })
113 .await;
114 }
115
116 /// `receive` may be called successfully at most once.
test_tcp_receive_once(family: IpAddressFamily)117 async fn test_tcp_receive_once(family: IpAddressFamily) {
118 setup(family, |mut server, client| async move {
119 // Give the client some potential data to _hopefully never_ read.
120 {
121 let rest = server.send_stream.write_all(b"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.".into()).await;
122 assert!(rest.is_empty());
123 }
124
125 // FYI, the first call to `receive` is part of the `setup` code, so every
126 // `receive` in here should fail.
127 for _ in 0..3 {
128 let (mut reader, future) = client.socket.receive();
129
130 let (stream_result, data) = reader.read(Vec::with_capacity(10)).await;
131 assert_eq!(data.len(), 0);
132 assert_eq!(stream_result, StreamResult::Dropped);
133 assert!(matches!(future.await, Err(ErrorCode::InvalidState)));
134 }
135 })
136 .await;
137 }
138
139 /// `send` may be called successfully at most once.
test_tcp_send_once(family: IpAddressFamily)140 async fn test_tcp_send_once(family: IpAddressFamily) {
141 setup(family, |_server, client| async move {
142 // FYI, the first call to `send` is part of the `setup` code, so every
143 // `send` in here should fail.
144 for _ in 0..3 {
145 let (mut writer, send_rx) = wit_stream::new();
146 let future = client.socket.send(send_rx);
147
148 const DATA: &[u8] = b"undeliverable";
149 let (stream_result, rest) = writer.write(DATA.into()).await;
150 assert_eq!(rest.into_vec(), DATA);
151 assert_eq!(stream_result, StreamResult::Dropped);
152 assert!(matches!(future.await, Err(ErrorCode::InvalidState)));
153 }
154 })
155 .await;
156 }
157
158 /// The streams and futures returned by `send` and `receive` should remain
159 /// operational even after the socket that spawned them has been dropped.
test_tcp_stream_lifetimes(family: IpAddressFamily)160 async fn test_tcp_stream_lifetimes(family: IpAddressFamily) {
161 setup(family, |server, client| async move {
162 let Connection {
163 socket: server_socket,
164 send_stream: mut server_send_stream,
165 receive_stream: server_receive_stream,
166 send_result: server_send_result,
167 receive_result: server_receive_result,
168 } = server;
169 let Connection {
170 socket: client_socket,
171 send_stream: mut client_send_stream,
172 receive_stream: client_receive_stream,
173 send_result: client_send_result,
174 receive_result: client_receive_result,
175 } = client;
176
177 // Drop the parent sockets:
178 drop(server_socket);
179 drop(client_socket);
180
181 {
182 let rest = server_send_stream.write_all(b"ping".into()).await;
183 assert!(rest.is_empty());
184 drop(server_send_stream);
185 server_send_result.await.unwrap();
186 }
187 {
188 let data = client_receive_stream.collect().await;
189 assert_eq!(data, b"ping");
190 client_receive_result.await.unwrap();
191 }
192 {
193 let rest = client_send_stream.write_all(b"pong".into()).await;
194 assert!(rest.is_empty());
195 drop(client_send_stream);
196 client_send_result.await.unwrap();
197 }
198 {
199 let data = server_receive_stream.collect().await;
200 assert_eq!(data, b"pong");
201 server_receive_result.await.unwrap();
202 }
203 })
204 .await;
205 }
206
207 /// Model a situation where there's a continuous stream of data coming into the
208 /// guest from one side and the other side is reading in chunks but also
209 /// cancelling reads occasionally. Should receive the complete stream of data
210 /// into the result.
test_tcp_read_cancellation(family: IpAddressFamily)211 async fn test_tcp_read_cancellation(family: IpAddressFamily) {
212 // Send 2M of data in 256-byte chunks.
213 const CHUNKS: usize = (2 << 20) / 256;
214 let mut data = [0; 256];
215 for (i, slot) in data.iter_mut().enumerate() {
216 *slot = i as u8;
217 }
218
219 setup(family, |mut server, mut client| async move {
220 // Minimize the local send buffer:
221 client.socket.set_send_buffer_size(1024).unwrap();
222
223 join!(
224 async {
225 for _ in 0..CHUNKS {
226 let ret = client.send_stream.write_all(data.to_vec()).await;
227 assert!(ret.is_empty());
228 }
229 drop(client.send_stream);
230 },
231 async {
232 let mut buf = Vec::with_capacity(1024);
233 let mut i = 0_usize;
234 let mut consecutive_zero_length_reads = 0;
235 loop {
236 assert!(buf.is_empty());
237 let (status, b) = {
238 let mut fut = pin!(server.receive_stream.read(buf));
239 let mut cx = Context::from_waker(Waker::noop());
240 match fut.as_mut().poll(&mut cx) {
241 Poll::Ready(pair) => pair,
242 Poll::Pending => fut.cancel(),
243 }
244 };
245 buf = b;
246 match status {
247 StreamResult::Complete(n) => {
248 assert_eq!(buf.len(), n);
249 for slot in buf.iter_mut() {
250 assert_eq!(*slot, i as u8);
251 i = i.wrapping_add(1);
252 }
253 buf.truncate(0);
254 consecutive_zero_length_reads = 0;
255 }
256 StreamResult::Dropped => break,
257 StreamResult::Cancelled => {
258 assert!(consecutive_zero_length_reads < 10);
259 consecutive_zero_length_reads += 1;
260 server.receive_stream.read(Vec::new()).await;
261 }
262 }
263 }
264 assert_eq!(i, CHUNKS * 256);
265 server.receive_result.await.unwrap();
266 },
267 );
268 })
269 .await;
270 }
271
272 impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>273 async fn run() -> Result<(), ()> {
274 test_tcp_ping_pong(IpAddressFamily::Ipv4).await;
275 test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(IpAddressFamily::Ipv4).await;
276 test_tcp_receive_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv4).await;
277 test_tcp_send_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv4).await;
278 test_tcp_send_drops_stream_when_remote_shutdown(IpAddressFamily::Ipv4).await;
279 test_tcp_receive_once(IpAddressFamily::Ipv4).await;
280 test_tcp_send_once(IpAddressFamily::Ipv4).await;
281 test_tcp_stream_lifetimes(IpAddressFamily::Ipv4).await;
282 test_tcp_read_cancellation(IpAddressFamily::Ipv4).await;
283
284 if supports_ipv6() {
285 test_tcp_ping_pong(IpAddressFamily::Ipv6).await;
286 test_tcp_receive_stream_should_be_dropped_by_remote_shutdown(IpAddressFamily::Ipv6)
287 .await;
288 test_tcp_receive_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv6).await;
289 test_tcp_send_future_should_resolve_when_stream_dropped(IpAddressFamily::Ipv6).await;
290 test_tcp_send_drops_stream_when_remote_shutdown(IpAddressFamily::Ipv6).await;
291 test_tcp_receive_once(IpAddressFamily::Ipv6).await;
292 test_tcp_send_once(IpAddressFamily::Ipv6).await;
293 test_tcp_stream_lifetimes(IpAddressFamily::Ipv6).await;
294 test_tcp_read_cancellation(IpAddressFamily::Ipv6).await;
295 }
296 Ok(())
297 }
298 }
299
/// Empty `main`; the test cases in this file are invoked from the
/// `Guest::run` implementation on `Component`.
fn main() {}
301
302 struct Connection {
303 socket: TcpSocket,
304 receive_stream: StreamReader<u8>,
305 receive_result: FutureReader<Result<(), ErrorCode>>,
306 send_stream: StreamWriter<u8>,
307 send_result: FutureReader<Result<(), ErrorCode>>,
308 }
309 impl Connection {
new(socket: TcpSocket) -> Self310 fn new(socket: TcpSocket) -> Self {
311 let (send_stream, send_rx) = wit_stream::new();
312 let send_result = socket.send(send_rx);
313 let (receive_stream, receive_result) = socket.receive();
314 Self {
315 socket,
316 receive_stream,
317 receive_result,
318 send_stream,
319 send_result,
320 }
321 }
322 }
323
324 /// Set up a connected pair of sockets
setup<Fut: Future<Output = ()>>( family: IpAddressFamily, body: impl FnOnce(Connection, Connection) -> Fut, )325 async fn setup<Fut: Future<Output = ()>>(
326 family: IpAddressFamily,
327 body: impl FnOnce(Connection, Connection) -> Fut,
328 ) {
329 let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
330 let listener = TcpSocket::create(family).unwrap();
331 listener.bind(bind_address).unwrap();
332 let mut accept = listener.listen().unwrap();
333 let bound_address = listener.get_local_address().unwrap();
334 let client_socket = TcpSocket::create(family).unwrap();
335 let ((), accepted_socket) = join!(
336 async {
337 client_socket.connect(bound_address).await.unwrap();
338 },
339 async { accept.next().await.unwrap() },
340 );
341
342 body(
343 Connection::new(accepted_socket),
344 Connection::new(client_socket),
345 )
346 .await;
347 }
348