xref: /webrtc/interceptor/src/twcc/receiver/mod.rs (revision 7220f446)
1 mod receiver_stream;
2 #[cfg(test)]
3 mod receiver_test;
4 
5 use crate::twcc::sender::TRANSPORT_CC_URI;
6 use crate::twcc::Recorder;
7 use crate::*;
8 use receiver_stream::ReceiverStream;
9 
10 use rtp::extension::transport_cc_extension::TransportCcExtension;
11 use std::time::Duration;
12 use tokio::sync::{mpsc, Mutex};
13 use tokio::time::MissedTickBehavior;
14 use util::Unmarshal;
15 use waitgroup::WaitGroup;
16 
/// ReceiverBuilder is an InterceptorBuilder for a [`Receiver`], the interceptor that
/// generates transport-wide congestion control (TWCC) feedback reports.
///
/// Use [`ReceiverBuilder::with_interval`] to override how often feedback is sent; when no
/// interval is configured, building the interceptor falls back to 100ms.
#[derive(Default)]
pub struct ReceiverBuilder {
    /// Optional override for the feedback send interval.
    interval: Option<Duration>,
}
22 
23 impl ReceiverBuilder {
24     /// with_interval sets send interval for the interceptor.
with_interval(mut self, interval: Duration) -> ReceiverBuilder25     pub fn with_interval(mut self, interval: Duration) -> ReceiverBuilder {
26         self.interval = Some(interval);
27         self
28     }
29 }
30 
31 impl InterceptorBuilder for ReceiverBuilder {
build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>>32     fn build(&self, _id: &str) -> Result<Arc<dyn Interceptor + Send + Sync>> {
33         let (close_tx, close_rx) = mpsc::channel(1);
34         let (packet_chan_tx, packet_chan_rx) = mpsc::channel(1);
35         Ok(Arc::new(Receiver {
36             internal: Arc::new(ReceiverInternal {
37                 interval: if let Some(interval) = &self.interval {
38                     *interval
39                 } else {
40                     Duration::from_millis(100)
41                 },
42                 recorder: Mutex::new(Recorder::default()),
43                 packet_chan_rx: Mutex::new(Some(packet_chan_rx)),
44                 streams: Mutex::new(HashMap::new()),
45                 close_rx: Mutex::new(Some(close_rx)),
46             }),
47             start_time: tokio::time::Instant::now(),
48             packet_chan_tx,
49             wg: Mutex::new(Some(WaitGroup::new())),
50             close_tx: Mutex::new(Some(close_tx)),
51         }))
52     }
53 }
54 
55 struct Packet {
56     hdr: rtp::header::Header,
57     sequence_number: u16,
58     arrival_time: i64,
59     ssrc: u32,
60 }
61 
62 struct ReceiverInternal {
63     interval: Duration,
64     recorder: Mutex<Recorder>,
65     packet_chan_rx: Mutex<Option<mpsc::Receiver<Packet>>>,
66     streams: Mutex<HashMap<u32, Arc<ReceiverStream>>>,
67     close_rx: Mutex<Option<mpsc::Receiver<()>>>,
68 }
69 
70 /// Receiver sends transport wide congestion control reports as specified in:
71 /// https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01
72 pub struct Receiver {
73     internal: Arc<ReceiverInternal>,
74 
75     // we use tokio's Instant because it makes testing easier via `tokio::time::advance`.
76     start_time: tokio::time::Instant,
77     packet_chan_tx: mpsc::Sender<Packet>,
78 
79     wg: Mutex<Option<WaitGroup>>,
80     close_tx: Mutex<Option<mpsc::Sender<()>>>,
81 }
82 
83 impl Receiver {
84     /// builder returns a new ReceiverBuilder.
builder() -> ReceiverBuilder85     pub fn builder() -> ReceiverBuilder {
86         ReceiverBuilder::default()
87     }
88 
is_closed(&self) -> bool89     async fn is_closed(&self) -> bool {
90         let close_tx = self.close_tx.lock().await;
91         close_tx.is_none()
92     }
93 
run( rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>, internal: Arc<ReceiverInternal>, ) -> Result<()>94     async fn run(
95         rtcp_writer: Arc<dyn RTCPWriter + Send + Sync>,
96         internal: Arc<ReceiverInternal>,
97     ) -> Result<()> {
98         let mut close_rx = {
99             let mut close_rx = internal.close_rx.lock().await;
100             if let Some(close_rx) = close_rx.take() {
101                 close_rx
102             } else {
103                 return Err(Error::ErrInvalidCloseRx);
104             }
105         };
106         let mut packet_chan_rx = {
107             let mut packet_chan_rx = internal.packet_chan_rx.lock().await;
108             if let Some(packet_chan_rx) = packet_chan_rx.take() {
109                 packet_chan_rx
110             } else {
111                 return Err(Error::ErrInvalidPacketRx);
112             }
113         };
114 
115         let a = Attributes::new();
116         let mut ticker = tokio::time::interval(internal.interval);
117         ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
118         loop {
119             tokio::select! {
120                 _ = close_rx.recv() =>{
121                     return Ok(());
122                 }
123                 p = packet_chan_rx.recv() => {
124                     if let Some(p) = p {
125                         let mut recorder = internal.recorder.lock().await;
126                         recorder.record(p.ssrc, p.sequence_number, p.arrival_time);
127                     }
128                 }
129                 _ = ticker.tick() =>{
130                     // build and send twcc
131                     let pkts = {
132                         let mut recorder = internal.recorder.lock().await;
133                         recorder.build_feedback_packet()
134                     };
135 
136                     if pkts.is_empty() {
137                         continue;
138                     }
139 
140                     if let Err(err) = rtcp_writer.write(&pkts, &a).await{
141                         log::error!("rtcp_writer.write got err: {}", err);
142                     }
143                 }
144             }
145         }
146     }
147 }
148 
149 #[async_trait]
150 impl Interceptor for Receiver {
151     /// bind_rtcp_reader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might
152     /// change in the future. The returned method will be called once per packet batch.
bind_rtcp_reader( &self, reader: Arc<dyn RTCPReader + Send + Sync>, ) -> Arc<dyn RTCPReader + Send + Sync>153     async fn bind_rtcp_reader(
154         &self,
155         reader: Arc<dyn RTCPReader + Send + Sync>,
156     ) -> Arc<dyn RTCPReader + Send + Sync> {
157         reader
158     }
159 
160     /// bind_rtcp_writer lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method
161     /// will be called once per packet batch.
bind_rtcp_writer( &self, writer: Arc<dyn RTCPWriter + Send + Sync>, ) -> Arc<dyn RTCPWriter + Send + Sync>162     async fn bind_rtcp_writer(
163         &self,
164         writer: Arc<dyn RTCPWriter + Send + Sync>,
165     ) -> Arc<dyn RTCPWriter + Send + Sync> {
166         if self.is_closed().await {
167             return writer;
168         }
169 
170         {
171             let mut recorder = self.internal.recorder.lock().await;
172             *recorder = Recorder::new(rand::random::<u32>());
173         }
174 
175         let mut w = {
176             let wait_group = self.wg.lock().await;
177             wait_group.as_ref().map(|wg| wg.worker())
178         };
179         let writer2 = Arc::clone(&writer);
180         let internal = Arc::clone(&self.internal);
181         tokio::spawn(async move {
182             let _d = w.take();
183             if let Err(err) = Receiver::run(writer2, internal).await {
184                 log::warn!("bind_rtcp_writer TWCC Sender::run got error: {}", err);
185             }
186         });
187 
188         writer
189     }
190 
191     /// bind_local_stream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method
192     /// will be called once per rtp packet.
bind_local_stream( &self, _info: &StreamInfo, writer: Arc<dyn RTPWriter + Send + Sync>, ) -> Arc<dyn RTPWriter + Send + Sync>193     async fn bind_local_stream(
194         &self,
195         _info: &StreamInfo,
196         writer: Arc<dyn RTPWriter + Send + Sync>,
197     ) -> Arc<dyn RTPWriter + Send + Sync> {
198         writer
199     }
200 
201     /// unbind_local_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
unbind_local_stream(&self, _info: &StreamInfo)202     async fn unbind_local_stream(&self, _info: &StreamInfo) {}
203 
204     /// bind_remote_stream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method
205     /// will be called once per rtp packet.
bind_remote_stream( &self, info: &StreamInfo, reader: Arc<dyn RTPReader + Send + Sync>, ) -> Arc<dyn RTPReader + Send + Sync>206     async fn bind_remote_stream(
207         &self,
208         info: &StreamInfo,
209         reader: Arc<dyn RTPReader + Send + Sync>,
210     ) -> Arc<dyn RTPReader + Send + Sync> {
211         let mut hdr_ext_id = 0u8;
212         for e in &info.rtp_header_extensions {
213             if e.uri == TRANSPORT_CC_URI {
214                 hdr_ext_id = e.id as u8;
215                 break;
216             }
217         }
218         if hdr_ext_id == 0 {
219             // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID
220             return reader;
221         }
222 
223         let stream = Arc::new(ReceiverStream::new(
224             reader,
225             hdr_ext_id,
226             info.ssrc,
227             self.packet_chan_tx.clone(),
228             self.start_time,
229         ));
230 
231         {
232             let mut streams = self.internal.streams.lock().await;
233             streams.insert(info.ssrc, Arc::clone(&stream));
234         }
235 
236         stream
237     }
238 
239     /// unbind_remote_stream is called when the Stream is removed. It can be used to clean up any data related to that track.
unbind_remote_stream(&self, info: &StreamInfo)240     async fn unbind_remote_stream(&self, info: &StreamInfo) {
241         let mut streams = self.internal.streams.lock().await;
242         streams.remove(&info.ssrc);
243     }
244 
245     /// close closes the Interceptor, cleaning up any data if necessary.
close(&self) -> Result<()>246     async fn close(&self) -> Result<()> {
247         {
248             let mut close_tx = self.close_tx.lock().await;
249             close_tx.take();
250         }
251 
252         {
253             let mut wait_group = self.wg.lock().await;
254             if let Some(wg) = wait_group.take() {
255                 wg.wait().await;
256             }
257         }
258 
259         Ok(())
260     }
261 }
262