xref: /webrtc/interceptor/src/nack/generator/mod.rs (revision ffe74184)
1 mod generator_stream;
2 #[cfg(test)]
3 mod generator_test;
4 
5 use generator_stream::GeneratorStream;
6 
7 use crate::error::{Error, Result};
8 use crate::stream_info::StreamInfo;
9 use crate::{Attributes, Interceptor, RTCPReader, RTPReader, RTPWriter};
10 use crate::{InterceptorBuilder, RTCPWriter};
11 
12 use crate::nack::stream_support_nack;
13 
14 use async_trait::async_trait;
15 use rtcp::transport_feedbacks::transport_layer_nack::{
16     nack_pairs_from_sequence_numbers, TransportLayerNack,
17 };
18 use std::collections::HashMap;
19 use std::sync::Arc;
20 use std::time::Duration;
21 use tokio::sync::{mpsc, Mutex};
22 use waitgroup::WaitGroup;
23 
/// GeneratorBuilder can be used to configure Generator Interceptor.
///
/// Every option is optional. Options left unset fall back to the defaults applied by
/// `build`:
/// - a receive log of 8192 packets (`log2_size_minus_6 = 7`),
/// - `skip_last_n = 0`,
/// - a 100ms NACK send interval.
#[derive(Debug, Default)]
pub struct GeneratorBuilder {
    log2_size_minus_6: Option<u8>,
    skip_last_n: Option<u16>,
    interval: Option<Duration>,
}

impl GeneratorBuilder {
    /// with_log2_size_minus_6 sets the size of each stream's receive log.
    ///
    /// The size is expressed as `log2(size) - 6`, so the log holds
    /// `64 << log2_size_minus_6` packets. Size must be one of: 64, 128, 256, 512, 1024,
    /// 2048, 4096, 8192, 16384, 32768. In other words, `log2_size_minus_6` must be in
    /// `0..=9` (for example `7` for 8192, the default).
    ///
    /// NOTE: the value is not validated here or in `build`; callers must stay in range.
    #[must_use]
    pub fn with_log2_size_minus_6(mut self, log2_size_minus_6: u8) -> GeneratorBuilder {
        self.log2_size_minus_6 = Some(log2_size_minus_6);
        self
    }

    /// with_skip_last_n sets the number of packets (n-1 packets before the last received
    /// packets) to ignore when generating nack requests. Defaults to 0.
    #[must_use]
    pub fn with_skip_last_n(mut self, skip_last_n: u16) -> GeneratorBuilder {
        self.skip_last_n = Some(skip_last_n);
        self
    }

    /// with_interval sets the nack send interval for the interceptor. Defaults to 100ms.
    ///
    /// The interval must be non-zero. It is handed to `tokio::time::interval`, which
    /// panics on a zero period, so a zero value would kill the background NACK task.
    #[must_use]
    pub fn with_interval(mut self, interval: Duration) -> GeneratorBuilder {
        self.interval = Some(interval);
        self
    }
}
53 
54 impl InterceptorBuilder for GeneratorBuilder {
build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>>55     fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
56         let (close_tx, close_rx) = mpsc::channel(1);
57         Ok(Arc::new(Generator {
58             internal: Arc::new(GeneratorInternal {
59                 log2_size_minus_6: if let Some(log2_size_minus_6) = self.log2_size_minus_6 {
60                     log2_size_minus_6
61                 } else {
62                     13 - 6 // 8192 = 1 << 13
63                 },
64                 skip_last_n: if let Some(skip_last_n) = self.skip_last_n {
65                     skip_last_n
66                 } else {
67                     0
68                 },
69                 interval: if let Some(interval) = self.interval {
70                     interval
71                 } else {
72                     Duration::from_millis(100)
73                 },
74 
75                 streams: Mutex::new(HashMap::new()),
76                 close_rx: Mutex::new(Some(close_rx)),
77             }),
78 
79             wg: Mutex::new(Some(WaitGroup::new())),
80             close_tx: Mutex::new(Some(close_tx)),
81         }))
82     }
83 }
84 
85 struct GeneratorInternal {
86     log2_size_minus_6: u8,
87     skip_last_n: u16,
88     interval: Duration,
89 
90     streams: Mutex<HashMap<u32, Arc<GeneratorStream>>>,
91     close_rx: Mutex<Option<mpsc::Receiver<()>>>,
92 }
93 
94 /// Generator interceptor generates nack feedback messages.
95 pub struct Generator {
96     internal: Arc<GeneratorInternal>,
97 
98     pub(crate) wg: Mutex<Option<WaitGroup>>,
99     pub(crate) close_tx: Mutex<Option<mpsc::Sender<()>>>,
100 }
101 
102 impl Generator {
103     /// builder returns a new GeneratorBuilder.
builder() -> GeneratorBuilder104     pub fn builder() -> GeneratorBuilder {
105         GeneratorBuilder::default()
106     }
107 
is_closed(&self) -> bool108     async fn is_closed(&self) -> bool {
109         let close_tx = self.close_tx.lock().await;
110         close_tx.is_none()
111     }
112 
run( rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>, internal: Arc<GeneratorInternal>, ) -> Result<()>113     async fn run(
114         rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
115         internal: Arc<GeneratorInternal>,
116     ) -> Result<()> {
117         let mut ticker = tokio::time::interval(internal.interval);
118         let mut close_rx = {
119             let mut close_rx = internal.close_rx.lock().await;
120             if let Some(close) = close_rx.take() {
121                 close
122             } else {
123                 return Err(Error::ErrInvalidCloseRx);
124             }
125         };
126 
127         let sender_ssrc = rand::random::<u32>();
128         loop {
129             tokio::select! {
130                 _ = ticker.tick() =>{
131                     let nacks = {
132                         let mut nacks = vec![];
133                         let streams = internal.streams.lock().await;
134                         for (ssrc, stream) in streams.iter() {
135                             let missing = stream.missing_seq_numbers(internal.skip_last_n);
136                             if missing.is_empty(){
137                                 continue;
138                             }
139 
140                             nacks.push(TransportLayerNack{
141                                 sender_ssrc,
142                                 media_ssrc: *ssrc,
143                                 nacks:  nack_pairs_from_sequence_numbers(&missing),
144                             });
145                         }
146                         nacks
147                     };
148 
149                     let a = Attributes::new();
150                     for nack in nacks{
151                         if let Err(err) = rtcp_writer.write(&[Box::new(nack)], &a).await{
152                             log::warn!("failed sending nack: {}", err);
153                         }
154                     }
155                 }
156                 _ = close_rx.recv() =>{
157                     return Ok(());
158                 }
159             }
160         }
161     }
162 }
163 
164 #[async_trait]
165 impl Interceptor for Generator {
166     /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might
167     /// change in the future. The returned method will be called once per packet batch.
bind_rtcp_reader( &self, reader: Arc<dyn RTCPReader + Send + Sync>, ) -> Arc<dyn RTCPReader + Send + Sync>168     async fn bind_rtcp_reader(
169         &self,
170         reader: Arc<dyn RTCPReader + Send + Sync>,
171     ) -> Arc<dyn RTCPReader + Send + Sync> {
172         reader
173     }
174 
175     /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method
176     /// will be called once per packet batch.
bind_rtcp_writer( &self, writer: Arc<dyn RTCPWriter + Send + Sync>, ) -> Arc<dyn RTCPWriter + Send + Sync>177     async fn bind_rtcp_writer(
178         &self,
179         writer: Arc<dyn RTCPWriter + Send + Sync>,
180     ) -> Arc<dyn RTCPWriter + Send + Sync> {
181         if self.is_closed().await {
182             return writer;
183         }
184 
185         let mut w = {
186             let wait_group = self.wg.lock().await;
187             wait_group.as_ref().map(|wg| wg.worker())
188         };
189         let writer2 = Arc::clone(&writer);
190         let internal = Arc::clone(&self.internal);
191         tokio::spawn(async move {
192             let _d = w.take();
193             if let Err(err) = Generator::run(writer2, internal).await {
194                 log::warn!("bind_rtcp_writer NACK Generator::run got error: {}", err);
195             }
196         });
197 
198         writer
199     }
200 
201     /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
202     /// will be called once per rtp packet.
bind_local_stream( &self, _info: &StreamInfo, writer: Arc<dyn RTPWriter + Send + Sync>, ) -> Arc<dyn RTPWriter + Send + Sync>203     async fn bind_local_stream(
204         &self,
205         _info: &StreamInfo,
206         writer: Arc<dyn RTPWriter + Send + Sync>,
207     ) -> Arc<dyn RTPWriter + Send + Sync> {
208         writer
209     }
210 
211     /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
unbind_local_stream(&self, _info: &StreamInfo)212     async fn unbind_local_stream(&self, _info: &StreamInfo) {}
213 
214     /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
215     /// will be called once per rtp packet.
bind_remote_stream( &self, info: &StreamInfo, reader: Arc<dyn RTPReader + Send + Sync>, ) -> Arc<dyn RTPReader + Send + Sync>216     async fn bind_remote_stream(
217         &self,
218         info: &StreamInfo,
219         reader: Arc<dyn RTPReader + Send + Sync>,
220     ) -> Arc<dyn RTPReader + Send + Sync> {
221         if !stream_support_nack(info) {
222             return reader;
223         }
224 
225         let stream = Arc::new(GeneratorStream::new(
226             self.internal.log2_size_minus_6,
227             reader,
228         ));
229         {
230             let mut streams = self.internal.streams.lock().await;
231             streams.insert(info.ssrc, Arc::clone(&stream));
232         }
233 
234         stream
235     }
236 
237     /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
unbind_remote_stream(&self, info: &StreamInfo)238     async fn unbind_remote_stream(&self, info: &StreamInfo) {
239         let mut receive_logs = self.internal.streams.lock().await;
240         receive_logs.remove(&info.ssrc);
241     }
242 
243     /// close closes the Interceptor, cleaning up any data if necessary.
close(&self) -> Result<()>244     async fn close(&self) -> Result<()> {
245         {
246             let mut close_tx = self.close_tx.lock().await;
247             close_tx.take();
248         }
249 
250         {
251             let mut wait_group = self.wg.lock().await;
252             if let Some(wg) = wait_group.take() {
253                 wg.wait().await;
254             }
255         }
256 
257         Ok(())
258     }
259 }
260