// xref: /webrtc/sctp/src/stream/stream_test.rs (revision 39ea30b8)
1 use super::*;
2 use std::sync::atomic::{AtomicU32, Ordering};
3 use std::sync::Arc;
4 use tokio::io::AsyncReadExt;
5 use tokio::io::AsyncWriteExt;
6 
/// Checks the buffered-amount and low-threshold getters on a default stream,
/// both before and after the underlying values are changed.
#[test]
fn test_stream_buffered_amount() -> Result<()> {
    let stream = Stream::default();

    // A default stream reports nothing buffered and a zero low threshold.
    assert_eq!(stream.buffered_amount(), 0);
    assert_eq!(stream.buffered_amount_low_threshold(), 0);

    // Poke the counter directly, then configure the threshold through its setter.
    stream.buffered_amount.store(8192, Ordering::SeqCst);
    stream.set_buffered_amount_low_threshold(2048);

    assert_eq!(stream.buffered_amount(), 8192, "unexpected bufferedAmount");
    let threshold = stream.buffered_amount_low_threshold();
    assert_eq!(threshold, 2048, "unexpected threshold");

    Ok(())
}

26 #[tokio::test]
test_stream_amount_on_buffered_amount_low() -> Result<()>27 async fn test_stream_amount_on_buffered_amount_low() -> Result<()> {
28     let s = Stream::default();
29 
30     s.buffered_amount.store(4096, Ordering::SeqCst);
31     s.set_buffered_amount_low_threshold(2048);
32 
33     let n_cbs = Arc::new(AtomicU32::new(0));
34     let n_cbs2 = n_cbs.clone();
35 
36     s.on_buffered_amount_low(Box::new(move || {
37         n_cbs2.fetch_add(1, Ordering::SeqCst);
38         Box::pin(async {})
39     }));
40 
41     // Negative value should be ignored (by design)
42     s.on_buffer_released(-32).await; // bufferedAmount = 3072
43     assert_eq!(s.buffered_amount(), 4096, "unexpected bufferedAmount");
44     assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch");
45 
46     // Above to above, no callback
47     s.on_buffer_released(1024).await; // bufferedAmount = 3072
48     assert_eq!(s.buffered_amount(), 3072, "unexpected bufferedAmount");
49     assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch");
50 
51     // Above to equal, callback should be made
52     s.on_buffer_released(1024).await; // bufferedAmount = 2048
53     assert_eq!(s.buffered_amount(), 2048, "unexpected bufferedAmount");
54     assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
55 
56     // Eaual to below, no callback
57     s.on_buffer_released(1024).await; // bufferedAmount = 1024
58     assert_eq!(s.buffered_amount(), 1024, "unexpected bufferedAmount");
59     assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
60 
61     // Blow to below, no callback
62     s.on_buffer_released(1024).await; // bufferedAmount = 0
63     assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
64     assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
65 
66     // Capped at 0, no callback
67     s.on_buffer_released(1024).await; // bufferedAmount = 0
68     assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
69     assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
70 
71     Ok(())
72 }
73 
74 #[tokio::test]
test_stream() -> std::result::Result<(), io::Error>75 async fn test_stream() -> std::result::Result<(), io::Error> {
76     let s = Stream::new(
77         "test_poll_stream".to_owned(),
78         0,
79         4096,
80         Arc::new(AtomicU32::new(4096)),
81         Arc::new(AtomicU8::new(AssociationState::Established as u8)),
82         None,
83         Arc::new(PendingQueue::new()),
84     );
85 
86     // getters
87     assert_eq!(s.stream_identifier(), 0);
88     assert_eq!(s.buffered_amount(), 0);
89     assert_eq!(s.buffered_amount_low_threshold(), 0);
90     assert_eq!(s.get_num_bytes_in_reassembly_queue().await, 0);
91 
92     // setters
93     s.set_default_payload_type(PayloadProtocolIdentifier::Binary);
94     s.set_reliability_params(true, ReliabilityType::Reliable, 0);
95 
96     // write
97     let n = s.write(&Bytes::from("Hello ")).await?;
98     assert_eq!(n, 6);
99     assert_eq!(s.buffered_amount(), 6);
100     let n = s
101         .write_sctp(&Bytes::from("world"), PayloadProtocolIdentifier::Binary)
102         .await?;
103     assert_eq!(n, 5);
104     assert_eq!(s.buffered_amount(), 11);
105 
106     // async read
107     //  1. pretend that we've received a chunk
108     s.handle_data(ChunkPayloadData {
109         unordered: true,
110         beginning_fragment: true,
111         ending_fragment: true,
112         user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
113         payload_type: PayloadProtocolIdentifier::Binary,
114         ..Default::default()
115     })
116     .await;
117     //  2. read it
118     let mut buf = [0; 5];
119     s.read(&mut buf).await?;
120     assert_eq!(buf, [0, 1, 2, 3, 4]);
121 
122     // shutdown write
123     s.shutdown(Shutdown::Write).await?;
124     // write must fail
125     assert!(s.write(&Bytes::from("error")).await.is_err());
126     // read should continue working
127     s.handle_data(ChunkPayloadData {
128         unordered: true,
129         beginning_fragment: true,
130         ending_fragment: true,
131         user_data: Bytes::from_static(&[5, 6, 7, 8, 9]),
132         payload_type: PayloadProtocolIdentifier::Binary,
133         ..Default::default()
134     })
135     .await;
136     let mut buf = [0; 5];
137     s.read(&mut buf).await?;
138     assert_eq!(buf, [5, 6, 7, 8, 9]);
139 
140     // shutdown read
141     s.shutdown(Shutdown::Read).await?;
142     // read must return 0
143     assert_eq!(s.read(&mut buf).await, Ok(0));
144 
145     Ok(())
146 }
147 
148 #[tokio::test]
test_poll_stream() -> std::result::Result<(), io::Error>149 async fn test_poll_stream() -> std::result::Result<(), io::Error> {
150     let s = Arc::new(Stream::new(
151         "test_poll_stream".to_owned(),
152         0,
153         4096,
154         Arc::new(AtomicU32::new(4096)),
155         Arc::new(AtomicU8::new(AssociationState::Established as u8)),
156         None,
157         Arc::new(PendingQueue::new()),
158     ));
159     let mut poll_stream = PollStream::new(s.clone());
160 
161     // getters
162     assert_eq!(poll_stream.stream_identifier(), 0);
163     assert_eq!(poll_stream.buffered_amount(), 0);
164     assert_eq!(poll_stream.buffered_amount_low_threshold(), 0);
165     assert_eq!(poll_stream.get_num_bytes_in_reassembly_queue().await, 0);
166 
167     // async write
168     let n = poll_stream.write(&[1, 2, 3]).await?;
169     assert_eq!(n, 3);
170     poll_stream.flush().await?;
171     assert_eq!(poll_stream.buffered_amount(), 3);
172 
173     // async read
174     //  1. pretend that we've received a chunk
175     let sc = s.clone();
176     sc.handle_data(ChunkPayloadData {
177         unordered: true,
178         beginning_fragment: true,
179         ending_fragment: true,
180         user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
181         payload_type: PayloadProtocolIdentifier::Binary,
182         ..Default::default()
183     })
184     .await;
185     //  2. read it
186     let mut buf = [0; 5];
187     poll_stream.read_exact(&mut buf).await?;
188     assert_eq!(buf, [0, 1, 2, 3, 4]);
189 
190     // shutdown write
191     poll_stream.shutdown().await?;
192     // write must fail
193     assert!(poll_stream.write(&[1, 2, 3]).await.is_err());
194     // read should continue working
195     sc.handle_data(ChunkPayloadData {
196         unordered: true,
197         beginning_fragment: true,
198         ending_fragment: true,
199         user_data: Bytes::from_static(&[5, 6, 7, 8, 9]),
200         payload_type: PayloadProtocolIdentifier::Binary,
201         ..Default::default()
202     })
203     .await;
204     let mut buf = [0; 5];
205     poll_stream.read_exact(&mut buf).await?;
206     assert_eq!(buf, [5, 6, 7, 8, 9]);
207 
208     // misc.
209     let clone = poll_stream.clone();
210     assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier());
211 
212     Ok(())
213 }
214