1 use super::*;
2 
3 use crate::nack::UINT16SIZE_HALF;
4 
5 use util::sync::Mutex;
6 use util::Unmarshal;
7 
8 struct GeneratorStreamInternal {
9     packets: Vec<u64>,
10     size: u16,
11     end: u16,
12     started: bool,
13     last_consecutive: u16,
14 }
15 
16 impl GeneratorStreamInternal {
new(log2_size_minus_6: u8) -> Self17     fn new(log2_size_minus_6: u8) -> Self {
18         GeneratorStreamInternal {
19             packets: vec![0u64; 1 << log2_size_minus_6],
20             size: 1 << (log2_size_minus_6 + 6),
21             end: 0,
22             started: false,
23             last_consecutive: 0,
24         }
25     }
26 
add(&mut self, seq: u16)27     fn add(&mut self, seq: u16) {
28         if !self.started {
29             self.set_received(seq);
30             self.end = seq;
31             self.started = true;
32             self.last_consecutive = seq;
33             return;
34         }
35 
36         let last_consecutive_plus1 = self.last_consecutive.wrapping_add(1);
37         let diff = seq.wrapping_sub(self.end);
38         if diff == 0 {
39             return;
40         } else if diff < UINT16SIZE_HALF {
41             // this means a positive diff, in other words seq > end (with counting for rollovers)
42             let mut i = self.end.wrapping_add(1);
43             while i != seq {
44                 // clear packets between end and seq (these may contain packets from a "size" ago)
45                 self.del_received(i);
46                 i = i.wrapping_add(1);
47             }
48             self.end = seq;
49 
50             let seq_sub_last_consecutive = seq.wrapping_sub(self.last_consecutive);
51             if last_consecutive_plus1 == seq {
52                 self.last_consecutive = seq;
53             } else if seq_sub_last_consecutive > self.size {
54                 let diff = seq.wrapping_sub(self.size);
55                 self.last_consecutive = diff;
56                 self.fix_last_consecutive(); // there might be valid packets at the beginning of the buffer now
57             }
58         } else if last_consecutive_plus1 == seq {
59             // negative diff, seq < end (with counting for rollovers)
60             self.last_consecutive = seq;
61             self.fix_last_consecutive(); // there might be other valid packets after seq
62         }
63 
64         self.set_received(seq);
65     }
66 
get(&self, seq: u16) -> bool67     fn get(&self, seq: u16) -> bool {
68         let diff = self.end.wrapping_sub(seq);
69         if diff >= UINT16SIZE_HALF {
70             return false;
71         }
72 
73         if diff >= self.size {
74             return false;
75         }
76 
77         self.get_received(seq)
78     }
79 
missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16>80     fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16> {
81         let until = self.end.wrapping_sub(skip_last_n);
82         let diff = until.wrapping_sub(self.last_consecutive);
83         if diff >= UINT16SIZE_HALF {
84             // until < s.last_consecutive (counting for rollover)
85             return vec![];
86         }
87 
88         let mut missing_packet_seq_nums = vec![];
89         let mut i = self.last_consecutive.wrapping_add(1);
90         let util_plus1 = until.wrapping_add(1);
91         while i != util_plus1 {
92             if !self.get_received(i) {
93                 missing_packet_seq_nums.push(i);
94             }
95             i = i.wrapping_add(1);
96         }
97 
98         missing_packet_seq_nums
99     }
100 
set_received(&mut self, seq: u16)101     fn set_received(&mut self, seq: u16) {
102         let pos = (seq % self.size) as usize;
103         self.packets[pos / 64] |= 1u64 << (pos % 64);
104     }
105 
del_received(&mut self, seq: u16)106     fn del_received(&mut self, seq: u16) {
107         let pos = (seq % self.size) as usize;
108         self.packets[pos / 64] &= u64::MAX ^ (1u64 << (pos % 64));
109     }
110 
get_received(&self, seq: u16) -> bool111     fn get_received(&self, seq: u16) -> bool {
112         let pos = (seq % self.size) as usize;
113         (self.packets[pos / 64] & (1u64 << (pos % 64))) != 0
114     }
115 
fix_last_consecutive(&mut self)116     fn fix_last_consecutive(&mut self) {
117         let mut i = self.last_consecutive.wrapping_add(1);
118         while i != self.end.wrapping_add(1) && self.get_received(i) {
119             // find all consecutive packets
120             i = i.wrapping_add(1);
121         }
122         self.last_consecutive = i.wrapping_sub(1);
123     }
124 }
125 
126 pub(super) struct GeneratorStream {
127     parent_rtp_reader: Arc<dyn RTPReader + Send + Sync>,
128 
129     internal: Mutex<GeneratorStreamInternal>,
130 }
131 
132 impl GeneratorStream {
new(log2_size_minus_6: u8, reader: Arc<dyn RTPReader + Send + Sync>) -> Self133     pub(super) fn new(log2_size_minus_6: u8, reader: Arc<dyn RTPReader + Send + Sync>) -> Self {
134         GeneratorStream {
135             parent_rtp_reader: reader,
136             internal: Mutex::new(GeneratorStreamInternal::new(log2_size_minus_6)),
137         }
138     }
139 
missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16>140     pub(super) fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec<u16> {
141         let internal = self.internal.lock();
142         internal.missing_seq_numbers(skip_last_n)
143     }
144 
add(&self, seq: u16)145     pub(super) fn add(&self, seq: u16) {
146         let mut internal = self.internal.lock();
147         internal.add(seq);
148     }
149 }
150 
151 /// RTPReader is used by Interceptor.bind_remote_stream.
152 #[async_trait]
153 impl RTPReader for GeneratorStream {
154     /// read a rtp packet
read(&self, buf: &mut [u8], a: &Attributes) -> Result<(usize, Attributes)>155     async fn read(&self, buf: &mut [u8], a: &Attributes) -> Result<(usize, Attributes)> {
156         let (n, attr) = self.parent_rtp_reader.read(buf, a).await?;
157 
158         let mut b = &buf[..n];
159         let pkt = rtp::packet::Packet::unmarshal(&mut b)?;
160         self.add(pkt.header.sequence_number);
161 
162         Ok((n, attr))
163     }
164 }
165 
#[cfg(test)]
mod test {
    use super::*;

    /// Replays one scripted arrival pattern against a 128-slot receive log,
    /// shifted by several starting offsets so that rollover is covered too.
    #[test]
    fn test_generator_stream() -> Result<()> {
        let starting_points: Vec<u16> = vec![
            0, 1, 127, 128, 129, 511, 512, 513, 32767, 32768, 32769, 65407, 65408, 65409, 65534,
            65535,
        ];
        for start in starting_points {
            let mut log = GeneratorStreamInternal::new(1);

            // Every number from `min` to `max`, both included (only called with min <= max).
            let all = |min: u16, max: u16| -> Vec<u16> { (min..=max).collect() };

            // Concatenates the given slices into one list.
            let join = |parts: &[&[u16]]| -> Vec<u16> { parts.concat() };

            // Feeds the given offsets (relative to `start`) into the log.
            let add = |log: &mut GeneratorStreamInternal, nums: &[u16]| {
                for &n in nums {
                    log.add(start.wrapping_add(n));
                }
            };

            // Every given offset must be reported as received.
            let assert_get = |log: &GeneratorStreamInternal, nums: &[u16]| {
                for seq in nums.iter().map(|n| start.wrapping_add(*n)) {
                    assert!(log.get(seq), "not found: {seq}");
                }
            };

            // None of the given offsets may be reported as received.
            let assert_not_get = |log: &GeneratorStreamInternal, nums: &[u16]| {
                for &n in nums {
                    let seq = start.wrapping_add(n);
                    assert!(
                        !log.get(seq),
                        "packet found: start {}, n {}, seq {}",
                        start,
                        n,
                        seq
                    );
                }
            };

            // The missing list must equal exactly the given offsets.
            let assert_missing = |log: &GeneratorStreamInternal, skip_last_n: u16, nums: &[u16]| {
                let missing = log.missing_seq_numbers(skip_last_n);
                let want: Vec<u16> = nums.iter().map(|n| start.wrapping_add(*n)).collect();
                assert_eq!(want, missing, "missing want/got, ");
            };

            // `last_consecutive` must sit at the given offset.
            let assert_last_consecutive = |log: &GeneratorStreamInternal, last_consecutive: u16| {
                let want = last_consecutive.wrapping_add(start);
                assert_eq!(log.last_consecutive, want, "invalid last_consecutive want");
            };

            add(&mut log, &[0]);
            assert_get(&log, &[0]);
            assert_missing(&log, 0, &[]);
            assert_last_consecutive(&log, 0); // first element added

            add(&mut log, &all(1, 127));
            assert_get(&log, &all(1, 127));
            assert_missing(&log, 0, &[]);
            assert_last_consecutive(&log, 127);

            add(&mut log, &[128]);
            assert_get(&log, &[128]);
            assert_not_get(&log, &[0]);
            assert_missing(&log, 0, &[]);
            assert_last_consecutive(&log, 128);

            add(&mut log, &[130]);
            assert_get(&log, &[130]);
            assert_not_get(&log, &[1, 2, 129]);
            assert_missing(&log, 0, &[129]);
            assert_last_consecutive(&log, 128);

            add(&mut log, &[333]);
            assert_get(&log, &[333]);
            assert_not_get(&log, &all(0, 332));
            assert_missing(&log, 0, &all(206, 332)); // all 127 elements missing before 333
            assert_missing(&log, 10, &all(206, 323)); // skip last 10 packets (324-333) from check
            assert_last_consecutive(&log, 205); // lastConsecutive is still out of the buffer

            add(&mut log, &[329]);
            assert_get(&log, &[329]);
            assert_missing(&log, 0, &join(&[&all(206, 328), &all(330, 332)]));
            assert_missing(&log, 5, &join(&[&all(206, 328)])); // skip last 5 packets (329-333) from check
            assert_last_consecutive(&log, 205);

            add(&mut log, &all(207, 320));
            assert_get(&log, &all(207, 320));
            assert_missing(&log, 0, &join(&[&[206], &all(321, 328), &all(330, 332)]));
            assert_last_consecutive(&log, 205);

            add(&mut log, &[334]);
            assert_get(&log, &[334]);
            assert_not_get(&log, &[206]);
            assert_missing(&log, 0, &join(&[&all(321, 328), &all(330, 332)]));
            assert_last_consecutive(&log, 320); // head of buffer is full of consecutive packages

            add(&mut log, &all(322, 328));
            assert_get(&log, &all(322, 328));
            assert_missing(&log, 0, &join(&[&[321], &all(330, 332)]));
            assert_last_consecutive(&log, 320);

            add(&mut log, &[321]);
            assert_get(&log, &[321]);
            assert_missing(&log, 0, &all(330, 332));
            assert_last_consecutive(&log, 329); // after adding a single missing packet, lastConsecutive should jump forward
        }

        Ok(())
    }

    /// Out-of-order arrivals right around the u16 rollover point must not panic.
    #[test]
    fn test_generator_stream_rollover() {
        let arrival_orders: [[u16; 3]; 2] = [[65533, 65535, 65534], [65534, 0, 65535]];
        for order in arrival_orders.iter() {
            let mut log = GeneratorStreamInternal::new(1);
            // Make sure it doesn't panic.
            for &seq in order.iter() {
                log.add(seq);
            }
        }
    }
}
315