xref: /webrtc/sctp/src/queue/reassembly_queue.rs (revision d9c41ff0)
1 use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
2 use crate::util::*;
3 
4 use crate::error::{Error, Result};
5 
6 use std::cmp::Ordering;
7 
8 fn sort_chunks_by_tsn(c: &mut [ChunkPayloadData]) {
9     c.sort_by(|a, b| {
10         if sna32lt(a.tsn, b.tsn) {
11             Ordering::Less
12         } else {
13             Ordering::Greater
14         }
15     });
16 }
17 
18 fn sort_chunks_by_ssn(c: &mut [ChunkSet]) {
19     c.sort_by(|a, b| {
20         if sna16lt(a.ssn, b.ssn) {
21             Ordering::Less
22         } else {
23             Ordering::Greater
24         }
25     });
26 }
27 
28 /// chunkSet is a set of chunks that share the same SSN
29 #[derive(Debug, Clone)]
30 pub(crate) struct ChunkSet {
31     /// used only with the ordered chunks
32     pub(crate) ssn: u16,
33     pub(crate) ppi: PayloadProtocolIdentifier,
34     pub(crate) chunks: Vec<ChunkPayloadData>,
35 }
36 
37 impl ChunkSet {
38     pub(crate) fn new(ssn: u16, ppi: PayloadProtocolIdentifier) -> Self {
39         ChunkSet {
40             ssn,
41             ppi,
42             chunks: vec![],
43         }
44     }
45 
46     pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool {
47         // check if dup
48         for c in &self.chunks {
49             if c.tsn == chunk.tsn {
50                 return false;
51             }
52         }
53 
54         // append and sort
55         self.chunks.push(chunk);
56         sort_chunks_by_tsn(&mut self.chunks);
57 
58         // Check if we now have a complete set
59         self.is_complete()
60     }
61 
62     pub(crate) fn is_complete(&self) -> bool {
63         // Condition for complete set
64         //   0. Has at least one chunk.
65         //   1. Begins with beginningFragment set to true
66         //   2. Ends with endingFragment set to true
67         //   3. TSN monotinically increase by 1 from beginning to end
68 
69         // 0.
70         let n_chunks = self.chunks.len();
71         if n_chunks == 0 {
72             return false;
73         }
74 
75         // 1.
76         if !self.chunks[0].beginning_fragment {
77             return false;
78         }
79 
80         // 2.
81         if !self.chunks[n_chunks - 1].ending_fragment {
82             return false;
83         }
84 
85         // 3.
86         let mut last_tsn = 0u32;
87         for (i, c) in self.chunks.iter().enumerate() {
88             if i > 0 {
89                 // Fragments must have contiguous TSN
90                 // From RFC 4960 Section 3.3.1:
91                 //   When a user message is fragmented into multiple chunks, the TSNs are
92                 //   used by the receiver to reassemble the message.  This means that the
93                 //   TSNs for each fragment of a fragmented user message MUST be strictly
94                 //   sequential.
95                 if c.tsn != last_tsn + 1 {
96                     // mid or end fragment is missing
97                     return false;
98                 }
99             }
100 
101             last_tsn = c.tsn;
102         }
103 
104         true
105     }
106 }
107 
108 #[derive(Default, Debug)]
109 pub(crate) struct ReassemblyQueue {
110     pub(crate) si: u16,
111     pub(crate) next_ssn: u16,
112     /// expected SSN for next ordered chunk
113     pub(crate) ordered: Vec<ChunkSet>,
114     pub(crate) unordered: Vec<ChunkSet>,
115     pub(crate) unordered_chunks: Vec<ChunkPayloadData>,
116     pub(crate) n_bytes: usize,
117 }
118 
119 impl ReassemblyQueue {
120     /// From RFC 4960 Sec 6.5:
121     ///   The Stream Sequence Number in all the streams MUST start from 0 when
122     ///   the association is Established.  Also, when the Stream Sequence
123     ///   Number reaches the value 65535 the next Stream Sequence Number MUST
124     ///   be set to 0.
125     pub(crate) fn new(si: u16) -> Self {
126         ReassemblyQueue {
127             si,
128             next_ssn: 0, // From RFC 4960 Sec 6.5:
129             ordered: vec![],
130             unordered: vec![],
131             unordered_chunks: vec![],
132             n_bytes: 0,
133         }
134     }
135 
136     pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool {
137         if chunk.stream_identifier != self.si {
138             return false;
139         }
140 
141         if chunk.unordered {
142             // First, insert into unordered_chunks array
143             //atomic.AddUint64(&r.n_bytes, uint64(len(chunk.userData)))
144             self.n_bytes += chunk.user_data.len();
145             self.unordered_chunks.push(chunk);
146             sort_chunks_by_tsn(&mut self.unordered_chunks);
147 
148             // Scan unordered_chunks that are contiguous (in TSN)
149             // If found, append the complete set to the unordered array
150             if let Some(cset) = self.find_complete_unordered_chunk_set() {
151                 self.unordered.push(cset);
152                 return true;
153             }
154 
155             false
156         } else {
157             // This is an ordered chunk
158             if sna16lt(chunk.stream_sequence_number, self.next_ssn) {
159                 return false;
160             }
161 
162             self.n_bytes += chunk.user_data.len();
163 
164             // Check if a chunkSet with the SSN already exists
165             for s in &mut self.ordered {
166                 if s.ssn == chunk.stream_sequence_number {
167                     return s.push(chunk);
168                 }
169             }
170 
171             // If not found, create a new chunkSet
172             let mut cset = ChunkSet::new(chunk.stream_sequence_number, chunk.payload_type);
173             let unordered = chunk.unordered;
174             let ok = cset.push(chunk);
175             self.ordered.push(cset);
176             if !unordered {
177                 sort_chunks_by_ssn(&mut self.ordered);
178             }
179 
180             ok
181         }
182     }
183 
184     pub(crate) fn find_complete_unordered_chunk_set(&mut self) -> Option<ChunkSet> {
185         let mut start_idx = -1isize;
186         let mut n_chunks = 0usize;
187         let mut last_tsn = 0u32;
188         let mut found = false;
189 
190         for (i, c) in self.unordered_chunks.iter().enumerate() {
191             // seek beigining
192             if c.beginning_fragment {
193                 start_idx = i as isize;
194                 n_chunks = 1;
195                 last_tsn = c.tsn;
196 
197                 if c.ending_fragment {
198                     found = true;
199                     break;
200                 }
201                 continue;
202             }
203 
204             if start_idx < 0 {
205                 continue;
206             }
207 
208             // Check if contiguous in TSN
209             if c.tsn != last_tsn + 1 {
210                 start_idx = -1;
211                 continue;
212             }
213 
214             last_tsn = c.tsn;
215             n_chunks += 1;
216 
217             if c.ending_fragment {
218                 found = true;
219                 break;
220             }
221         }
222 
223         if !found {
224             return None;
225         }
226 
227         // Extract the range of chunks
228         let chunks: Vec<ChunkPayloadData> = self
229             .unordered_chunks
230             .drain(start_idx as usize..(start_idx as usize) + n_chunks)
231             .collect();
232 
233         let mut chunk_set = ChunkSet::new(0, chunks[0].payload_type);
234         chunk_set.chunks = chunks;
235 
236         Some(chunk_set)
237     }
238 
239     pub(crate) fn is_readable(&self) -> bool {
240         // Check unordered first
241         if !self.unordered.is_empty() {
242             // The chunk sets in r.unordered should all be complete.
243             return true;
244         }
245 
246         // Check ordered sets
247         if !self.ordered.is_empty() {
248             let cset = &self.ordered[0];
249             if cset.is_complete() && sna16lte(cset.ssn, self.next_ssn) {
250                 return true;
251             }
252         }
253         false
254     }
255 
256     pub(crate) fn read(&mut self, buf: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
257         // Check unordered first
258         let cset = if !self.unordered.is_empty() {
259             self.unordered.remove(0)
260         } else if !self.ordered.is_empty() {
261             // Now, check ordered
262             let cset = &self.ordered[0];
263             if !cset.is_complete() {
264                 return Err(Error::ErrTryAgain);
265             }
266             if sna16gt(cset.ssn, self.next_ssn) {
267                 return Err(Error::ErrTryAgain);
268             }
269             if cset.ssn == self.next_ssn {
270                 // From RFC 4960 Sec 6.5:
271                 self.next_ssn = self.next_ssn.wrapping_add(1);
272             }
273             self.ordered.remove(0)
274         } else {
275             return Err(Error::ErrTryAgain);
276         };
277 
278         // Concat all fragments into the buffer
279         let mut n_written = 0;
280         let mut err = None;
281         for c in &cset.chunks {
282             let to_copy = c.user_data.len();
283             self.subtract_num_bytes(to_copy);
284             if err.is_none() {
285                 let n = std::cmp::min(to_copy, buf.len() - n_written);
286                 buf[n_written..n_written + n].copy_from_slice(&c.user_data[..n]);
287                 n_written += n;
288                 if n < to_copy {
289                     err = Some(Error::ErrShortBuffer);
290                 }
291             }
292         }
293 
294         if let Some(err) = err {
295             Err(err)
296         } else {
297             Ok((n_written, cset.ppi))
298         }
299     }
300 
301     /// Use last_ssn to locate a chunkSet then remove it if the set has
302     /// not been complete
303     pub(crate) fn forward_tsn_for_ordered(&mut self, last_ssn: u16) {
304         let num_bytes = self
305             .ordered
306             .iter()
307             .filter(|s| sna16lte(s.ssn, last_ssn) && !s.is_complete())
308             .fold(0, |n, s| {
309                 n + s.chunks.iter().fold(0, |acc, c| acc + c.user_data.len())
310             });
311         self.subtract_num_bytes(num_bytes);
312 
313         self.ordered
314             .retain(|s| !sna16lte(s.ssn, last_ssn) || s.is_complete());
315 
316         // Finally, forward next_ssn
317         if sna16lte(self.next_ssn, last_ssn) {
318             self.next_ssn = last_ssn.wrapping_add(1);
319         }
320     }
321 
322     /// Remove all fragments in the unordered sets that contains chunks
323     /// equal to or older than `new_cumulative_tsn`.
324     /// We know all sets in the r.unordered are complete ones.
325     /// Just remove chunks that are equal to or older than new_cumulative_tsn
326     /// from the unordered_chunks
327     pub(crate) fn forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) {
328         let mut last_idx: isize = -1;
329         for (i, c) in self.unordered_chunks.iter().enumerate() {
330             if sna32gt(c.tsn, new_cumulative_tsn) {
331                 break;
332             }
333             last_idx = i as isize;
334         }
335         if last_idx >= 0 {
336             for i in 0..(last_idx + 1) as usize {
337                 self.subtract_num_bytes(self.unordered_chunks[i].user_data.len());
338             }
339             self.unordered_chunks.drain(..(last_idx + 1) as usize);
340         }
341     }
342 
343     pub(crate) fn subtract_num_bytes(&mut self, n_bytes: usize) {
344         if self.n_bytes >= n_bytes {
345             self.n_bytes -= n_bytes;
346         } else {
347             self.n_bytes = 0;
348         }
349     }
350 
351     pub(crate) fn get_num_bytes(&self) -> usize {
352         self.n_bytes
353     }
354 }
355