xref: /webrtc/srtp/src/session/mod.rs (revision ffe74184)
1 #[cfg(test)]
2 mod session_rtcp_test;
3 #[cfg(test)]
4 mod session_rtp_test;
5 
6 use crate::{
7     config::*,
8     context::*,
9     error::{Error, Result},
10     option::*,
11     stream::*,
12 };
13 use util::{conn::Conn, marshal::*};
14 
15 use bytes::Bytes;
16 use std::collections::HashSet;
17 use std::{
18     collections::HashMap,
19     marker::{Send, Sync},
20     sync::Arc,
21 };
22 use tokio::sync::{mpsc, Mutex};
23 
/// Window passed to `srtp_replay_protection` for the remote (inbound) context
/// when `Config::remote_rtp_options` is left unset.
const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64;

/// Window passed to `srtcp_replay_protection` for the remote (inbound) context
/// when `Config::remote_rtcp_options` is left unset.
const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
26 
27 /// Session implements io.ReadWriteCloser and provides a bi-directional SRTP session
28 /// SRTP itself does not have a design like this, but it is common in most applications
29 /// for local/remote to each have their own keying material. This provides those patterns
30 /// instead of making everyone re-implement
31 pub struct Session {
32     local_context: Arc<Mutex<Context>>,
33     streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
34     new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
35     close_stream_tx: mpsc::Sender<u32>,
36     close_session_tx: mpsc::Sender<()>,
37     pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
38     is_rtp: bool,
39 }
40 
41 impl Session {
new( conn: Arc<dyn Conn + Send + Sync>, config: Config, is_rtp: bool, ) -> Result<Self>42     pub async fn new(
43         conn: Arc<dyn Conn + Send + Sync>,
44         config: Config,
45         is_rtp: bool,
46     ) -> Result<Self> {
47         let local_context = Context::new(
48             &config.keys.local_master_key,
49             &config.keys.local_master_salt,
50             config.profile,
51             config.local_rtp_options,
52             config.local_rtcp_options,
53         )?;
54 
55         let mut remote_context = Context::new(
56             &config.keys.remote_master_key,
57             &config.keys.remote_master_salt,
58             config.profile,
59             if config.remote_rtp_options.is_none() {
60                 Some(srtp_replay_protection(
61                     DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW,
62                 ))
63             } else {
64                 config.remote_rtp_options
65             },
66             if config.remote_rtcp_options.is_none() {
67                 Some(srtcp_replay_protection(
68                     DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW,
69                 ))
70             } else {
71                 config.remote_rtcp_options
72             },
73         )?;
74 
75         let streams_map = Arc::new(Mutex::new(HashMap::new()));
76         let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8);
77         let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8);
78         let (close_session_tx, mut close_session_rx) = mpsc::channel(8);
79         let udp_tx = Arc::clone(&conn);
80         let udp_rx = Arc::clone(&conn);
81         let cloned_streams_map = Arc::clone(&streams_map);
82         let cloned_close_stream_tx = close_stream_tx.clone();
83 
84         tokio::spawn(async move {
85             let mut buf = vec![0u8; 8192];
86 
87             loop {
88                 let incoming_stream = Session::incoming(
89                     &udp_rx,
90                     &mut buf,
91                     &cloned_streams_map,
92                     &cloned_close_stream_tx,
93                     &mut new_stream_tx,
94                     &mut remote_context,
95                     is_rtp,
96                 );
97                 let close_stream = close_stream_rx.recv();
98                 let close_session = close_session_rx.recv();
99 
100                 tokio::select! {
101                     result = incoming_stream => match result{
102                         Ok(()) => {},
103                         Err(err) => log::info!("{}", err),
104                     },
105                     opt = close_stream => if let Some(ssrc) = opt {
106                         Session::close_stream(&cloned_streams_map, ssrc).await
107                     },
108                     _ = close_session => break
109                 }
110             }
111         });
112 
113         Ok(Session {
114             local_context: Arc::new(Mutex::new(local_context)),
115             streams_map,
116             new_stream_rx: Arc::new(Mutex::new(new_stream_rx)),
117             close_stream_tx,
118             close_session_tx,
119             udp_tx,
120             is_rtp,
121         })
122     }
123 
close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32)124     async fn close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32) {
125         let mut streams = streams_map.lock().await;
126         streams.remove(&ssrc);
127     }
128 
incoming( udp_rx: &Arc<dyn Conn + Send + Sync>, buf: &mut [u8], streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, close_stream_tx: &mpsc::Sender<u32>, new_stream_tx: &mut mpsc::Sender<Arc<Stream>>, remote_context: &mut Context, is_rtp: bool, ) -> Result<()>129     async fn incoming(
130         udp_rx: &Arc<dyn Conn + Send + Sync>,
131         buf: &mut [u8],
132         streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
133         close_stream_tx: &mpsc::Sender<u32>,
134         new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
135         remote_context: &mut Context,
136         is_rtp: bool,
137     ) -> Result<()> {
138         let n = udp_rx.recv(buf).await?;
139         if n == 0 {
140             return Err(Error::SessionEof);
141         }
142 
143         let decrypted = if is_rtp {
144             remote_context.decrypt_rtp(&buf[0..n])?
145         } else {
146             remote_context.decrypt_rtcp(&buf[0..n])?
147         };
148 
149         let mut buf = &decrypted[..];
150         let ssrcs = if is_rtp {
151             vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
152         } else {
153             let pkts = rtcp::packet::unmarshal(&mut buf)?;
154             destination_ssrc(&pkts)
155         };
156 
157         for ssrc in ssrcs {
158             let (stream, is_new) =
159                 Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
160                     .await;
161             if is_new {
162                 log::trace!(
163                     "srtp session got new {} stream {}",
164                     if is_rtp { "rtp" } else { "rtcp" },
165                     ssrc
166                 );
167                 new_stream_tx.send(Arc::clone(&stream)).await?;
168             }
169 
170             match stream.buffer.write(&decrypted).await {
171                 Ok(_) => {}
172                 Err(err) => {
173                     // Silently drop data when the buffer is full.
174                     if util::Error::ErrBufferFull != err {
175                         return Err(err.into());
176                     }
177                 }
178             }
179         }
180 
181         Ok(())
182     }
183 
get_or_create_stream( streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, close_stream_tx: mpsc::Sender<u32>, is_rtp: bool, ssrc: u32, ) -> (Arc<Stream>, bool)184     async fn get_or_create_stream(
185         streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
186         close_stream_tx: mpsc::Sender<u32>,
187         is_rtp: bool,
188         ssrc: u32,
189     ) -> (Arc<Stream>, bool) {
190         let mut streams = streams_map.lock().await;
191 
192         if let Some(stream) = streams.get(&ssrc) {
193             (Arc::clone(stream), false)
194         } else {
195             let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp));
196             streams.insert(ssrc, Arc::clone(&stream));
197             (stream, true)
198         }
199     }
200 
201     /// open on the given SSRC to create a stream, it can be used
202     /// if you want a certain SSRC, but don't want to wait for Accept
open(&self, ssrc: u32) -> Arc<Stream>203     pub async fn open(&self, ssrc: u32) -> Arc<Stream> {
204         let (stream, _) = Session::get_or_create_stream(
205             &self.streams_map,
206             self.close_stream_tx.clone(),
207             self.is_rtp,
208             ssrc,
209         )
210         .await;
211 
212         stream
213     }
214 
215     /// accept returns a stream to handle RTCP for a single SSRC
accept(&self) -> Result<Arc<Stream>>216     pub async fn accept(&self) -> Result<Arc<Stream>> {
217         let mut new_stream_rx = self.new_stream_rx.lock().await;
218         let result = new_stream_rx.recv().await;
219         if let Some(stream) = result {
220             Ok(stream)
221         } else {
222             Err(Error::SessionSrtpAlreadyClosed)
223         }
224     }
225 
close(&self) -> Result<()>226     pub async fn close(&self) -> Result<()> {
227         self.close_session_tx.send(()).await?;
228 
229         Ok(())
230     }
231 
write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize>232     pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize> {
233         if self.is_rtp != is_rtp {
234             return Err(Error::SessionRtpRtcpTypeMismatch);
235         }
236 
237         let encrypted = {
238             let mut local_context = self.local_context.lock().await;
239 
240             if is_rtp {
241                 local_context.encrypt_rtp(buf)?
242             } else {
243                 local_context.encrypt_rtcp(buf)?
244             }
245         };
246 
247         Ok(self.udp_tx.send(&encrypted).await?)
248     }
249 
write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize>250     pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
251         let raw = pkt.marshal()?;
252         self.write(&raw, true).await
253     }
254 
write_rtcp( &self, pkt: &(dyn rtcp::packet::Packet + Send + Sync), ) -> Result<usize>255     pub async fn write_rtcp(
256         &self,
257         pkt: &(dyn rtcp::packet::Packet + Send + Sync),
258     ) -> Result<usize> {
259         let raw = pkt.marshal()?;
260         self.write(&raw, false).await
261     }
262 }
263 
264 /// create a list of Destination SSRCs
265 /// that's a superset of all Destinations in the slice.
destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32>266 fn destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32> {
267     let mut ssrc_set = HashSet::new();
268     for p in pkts {
269         for ssrc in p.destination_ssrc() {
270             ssrc_set.insert(ssrc);
271         }
272     }
273     ssrc_set.into_iter().collect()
274 }
275