1 use super::*;
2 use std::sync::atomic::{AtomicU32, Ordering};
3 use std::sync::Arc;
4 use tokio::io::AsyncReadExt;
5 use tokio::io::AsyncWriteExt;
6
/// Verifies the buffered-amount getters of a default-constructed `Stream`:
/// both the buffered amount and the low threshold start at zero, and the
/// getters report the values written through the `buffered_amount` atomic and
/// `set_buffered_amount_low_threshold`.
#[test]
fn test_stream_buffered_amount() -> Result<()> {
    let stream = Stream::default();

    // Freshly constructed: nothing buffered and no threshold configured.
    assert_eq!(stream.buffered_amount(), 0);
    assert_eq!(stream.buffered_amount_low_threshold(), 0);

    // Write both values, then read them back through the public getters.
    stream.buffered_amount.store(8192, Ordering::SeqCst);
    stream.set_buffered_amount_low_threshold(2048);

    assert_eq!(stream.buffered_amount(), 8192, "unexpected bufferedAmount");
    assert_eq!(
        stream.buffered_amount_low_threshold(),
        2048,
        "unexpected threshold"
    );

    Ok(())
}
25
26 #[tokio::test]
test_stream_amount_on_buffered_amount_low() -> Result<()>27 async fn test_stream_amount_on_buffered_amount_low() -> Result<()> {
28 let s = Stream::default();
29
30 s.buffered_amount.store(4096, Ordering::SeqCst);
31 s.set_buffered_amount_low_threshold(2048);
32
33 let n_cbs = Arc::new(AtomicU32::new(0));
34 let n_cbs2 = n_cbs.clone();
35
36 s.on_buffered_amount_low(Box::new(move || {
37 n_cbs2.fetch_add(1, Ordering::SeqCst);
38 Box::pin(async {})
39 }));
40
41 // Negative value should be ignored (by design)
42 s.on_buffer_released(-32).await; // bufferedAmount = 3072
43 assert_eq!(s.buffered_amount(), 4096, "unexpected bufferedAmount");
44 assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch");
45
46 // Above to above, no callback
47 s.on_buffer_released(1024).await; // bufferedAmount = 3072
48 assert_eq!(s.buffered_amount(), 3072, "unexpected bufferedAmount");
49 assert_eq!(n_cbs.load(Ordering::SeqCst), 0, "callback count mismatch");
50
51 // Above to equal, callback should be made
52 s.on_buffer_released(1024).await; // bufferedAmount = 2048
53 assert_eq!(s.buffered_amount(), 2048, "unexpected bufferedAmount");
54 assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
55
56 // Eaual to below, no callback
57 s.on_buffer_released(1024).await; // bufferedAmount = 1024
58 assert_eq!(s.buffered_amount(), 1024, "unexpected bufferedAmount");
59 assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
60
61 // Blow to below, no callback
62 s.on_buffer_released(1024).await; // bufferedAmount = 0
63 assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
64 assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
65
66 // Capped at 0, no callback
67 s.on_buffer_released(1024).await; // bufferedAmount = 0
68 assert_eq!(s.buffered_amount(), 0, "unexpected bufferedAmount");
69 assert_eq!(n_cbs.load(Ordering::SeqCst), 1, "callback count mismatch");
70
71 Ok(())
72 }
73
74 #[tokio::test]
test_stream() -> std::result::Result<(), io::Error>75 async fn test_stream() -> std::result::Result<(), io::Error> {
76 let s = Stream::new(
77 "test_poll_stream".to_owned(),
78 0,
79 4096,
80 Arc::new(AtomicU32::new(4096)),
81 Arc::new(AtomicU8::new(AssociationState::Established as u8)),
82 None,
83 Arc::new(PendingQueue::new()),
84 );
85
86 // getters
87 assert_eq!(s.stream_identifier(), 0);
88 assert_eq!(s.buffered_amount(), 0);
89 assert_eq!(s.buffered_amount_low_threshold(), 0);
90 assert_eq!(s.get_num_bytes_in_reassembly_queue().await, 0);
91
92 // setters
93 s.set_default_payload_type(PayloadProtocolIdentifier::Binary);
94 s.set_reliability_params(true, ReliabilityType::Reliable, 0);
95
96 // write
97 let n = s.write(&Bytes::from("Hello ")).await?;
98 assert_eq!(n, 6);
99 assert_eq!(s.buffered_amount(), 6);
100 let n = s
101 .write_sctp(&Bytes::from("world"), PayloadProtocolIdentifier::Binary)
102 .await?;
103 assert_eq!(n, 5);
104 assert_eq!(s.buffered_amount(), 11);
105
106 // async read
107 // 1. pretend that we've received a chunk
108 s.handle_data(ChunkPayloadData {
109 unordered: true,
110 beginning_fragment: true,
111 ending_fragment: true,
112 user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
113 payload_type: PayloadProtocolIdentifier::Binary,
114 ..Default::default()
115 })
116 .await;
117 // 2. read it
118 let mut buf = [0; 5];
119 s.read(&mut buf).await?;
120 assert_eq!(buf, [0, 1, 2, 3, 4]);
121
122 // shutdown write
123 s.shutdown(Shutdown::Write).await?;
124 // write must fail
125 assert!(s.write(&Bytes::from("error")).await.is_err());
126 // read should continue working
127 s.handle_data(ChunkPayloadData {
128 unordered: true,
129 beginning_fragment: true,
130 ending_fragment: true,
131 user_data: Bytes::from_static(&[5, 6, 7, 8, 9]),
132 payload_type: PayloadProtocolIdentifier::Binary,
133 ..Default::default()
134 })
135 .await;
136 let mut buf = [0; 5];
137 s.read(&mut buf).await?;
138 assert_eq!(buf, [5, 6, 7, 8, 9]);
139
140 // shutdown read
141 s.shutdown(Shutdown::Read).await?;
142 // read must return 0
143 assert_eq!(s.read(&mut buf).await, Ok(0));
144
145 Ok(())
146 }
147
148 #[tokio::test]
test_poll_stream() -> std::result::Result<(), io::Error>149 async fn test_poll_stream() -> std::result::Result<(), io::Error> {
150 let s = Arc::new(Stream::new(
151 "test_poll_stream".to_owned(),
152 0,
153 4096,
154 Arc::new(AtomicU32::new(4096)),
155 Arc::new(AtomicU8::new(AssociationState::Established as u8)),
156 None,
157 Arc::new(PendingQueue::new()),
158 ));
159 let mut poll_stream = PollStream::new(s.clone());
160
161 // getters
162 assert_eq!(poll_stream.stream_identifier(), 0);
163 assert_eq!(poll_stream.buffered_amount(), 0);
164 assert_eq!(poll_stream.buffered_amount_low_threshold(), 0);
165 assert_eq!(poll_stream.get_num_bytes_in_reassembly_queue().await, 0);
166
167 // async write
168 let n = poll_stream.write(&[1, 2, 3]).await?;
169 assert_eq!(n, 3);
170 poll_stream.flush().await?;
171 assert_eq!(poll_stream.buffered_amount(), 3);
172
173 // async read
174 // 1. pretend that we've received a chunk
175 let sc = s.clone();
176 sc.handle_data(ChunkPayloadData {
177 unordered: true,
178 beginning_fragment: true,
179 ending_fragment: true,
180 user_data: Bytes::from_static(&[0, 1, 2, 3, 4]),
181 payload_type: PayloadProtocolIdentifier::Binary,
182 ..Default::default()
183 })
184 .await;
185 // 2. read it
186 let mut buf = [0; 5];
187 poll_stream.read_exact(&mut buf).await?;
188 assert_eq!(buf, [0, 1, 2, 3, 4]);
189
190 // shutdown write
191 poll_stream.shutdown().await?;
192 // write must fail
193 assert!(poll_stream.write(&[1, 2, 3]).await.is_err());
194 // read should continue working
195 sc.handle_data(ChunkPayloadData {
196 unordered: true,
197 beginning_fragment: true,
198 ending_fragment: true,
199 user_data: Bytes::from_static(&[5, 6, 7, 8, 9]),
200 payload_type: PayloadProtocolIdentifier::Binary,
201 ..Default::default()
202 })
203 .await;
204 let mut buf = [0; 5];
205 poll_stream.read_exact(&mut buf).await?;
206 assert_eq!(buf, [5, 6, 7, 8, 9]);
207
208 // misc.
209 let clone = poll_stream.clone();
210 assert_eq!(clone.stream_identifier(), poll_stream.stream_identifier());
211
212 Ok(())
213 }
214