1 #[cfg(test)]
2 mod session_rtcp_test;
3 #[cfg(test)]
4 mod session_rtp_test;
5
6 use crate::{
7 config::*,
8 context::*,
9 error::{Error, Result},
10 option::*,
11 stream::*,
12 };
13 use util::{conn::Conn, marshal::*};
14
15 use bytes::Bytes;
16 use std::collections::HashSet;
17 use std::{
18 collections::HashMap,
19 marker::{Send, Sync},
20 sync::Arc,
21 };
22 use tokio::sync::{mpsc, Mutex};
23
/// Replay-protection window size passed to `srtp_replay_protection` for the
/// remote context when `Config::remote_rtp_options` is `None`.
const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64;
/// Replay-protection window size passed to `srtcp_replay_protection` for the
/// remote context when `Config::remote_rtcp_options` is `None`.
const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
26
27 /// Session implements io.ReadWriteCloser and provides a bi-directional SRTP session
28 /// SRTP itself does not have a design like this, but it is common in most applications
29 /// for local/remote to each have their own keying material. This provides those patterns
30 /// instead of making everyone re-implement
31 pub struct Session {
32 local_context: Arc<Mutex<Context>>,
33 streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
34 new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
35 close_stream_tx: mpsc::Sender<u32>,
36 close_session_tx: mpsc::Sender<()>,
37 pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
38 is_rtp: bool,
39 }
40
41 impl Session {
new( conn: Arc<dyn Conn + Send + Sync>, config: Config, is_rtp: bool, ) -> Result<Self>42 pub async fn new(
43 conn: Arc<dyn Conn + Send + Sync>,
44 config: Config,
45 is_rtp: bool,
46 ) -> Result<Self> {
47 let local_context = Context::new(
48 &config.keys.local_master_key,
49 &config.keys.local_master_salt,
50 config.profile,
51 config.local_rtp_options,
52 config.local_rtcp_options,
53 )?;
54
55 let mut remote_context = Context::new(
56 &config.keys.remote_master_key,
57 &config.keys.remote_master_salt,
58 config.profile,
59 if config.remote_rtp_options.is_none() {
60 Some(srtp_replay_protection(
61 DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW,
62 ))
63 } else {
64 config.remote_rtp_options
65 },
66 if config.remote_rtcp_options.is_none() {
67 Some(srtcp_replay_protection(
68 DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW,
69 ))
70 } else {
71 config.remote_rtcp_options
72 },
73 )?;
74
75 let streams_map = Arc::new(Mutex::new(HashMap::new()));
76 let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8);
77 let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8);
78 let (close_session_tx, mut close_session_rx) = mpsc::channel(8);
79 let udp_tx = Arc::clone(&conn);
80 let udp_rx = Arc::clone(&conn);
81 let cloned_streams_map = Arc::clone(&streams_map);
82 let cloned_close_stream_tx = close_stream_tx.clone();
83
84 tokio::spawn(async move {
85 let mut buf = vec![0u8; 8192];
86
87 loop {
88 let incoming_stream = Session::incoming(
89 &udp_rx,
90 &mut buf,
91 &cloned_streams_map,
92 &cloned_close_stream_tx,
93 &mut new_stream_tx,
94 &mut remote_context,
95 is_rtp,
96 );
97 let close_stream = close_stream_rx.recv();
98 let close_session = close_session_rx.recv();
99
100 tokio::select! {
101 result = incoming_stream => match result{
102 Ok(()) => {},
103 Err(err) => log::info!("{}", err),
104 },
105 opt = close_stream => if let Some(ssrc) = opt {
106 Session::close_stream(&cloned_streams_map, ssrc).await
107 },
108 _ = close_session => break
109 }
110 }
111 });
112
113 Ok(Session {
114 local_context: Arc::new(Mutex::new(local_context)),
115 streams_map,
116 new_stream_rx: Arc::new(Mutex::new(new_stream_rx)),
117 close_stream_tx,
118 close_session_tx,
119 udp_tx,
120 is_rtp,
121 })
122 }
123
close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32)124 async fn close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32) {
125 let mut streams = streams_map.lock().await;
126 streams.remove(&ssrc);
127 }
128
incoming( udp_rx: &Arc<dyn Conn + Send + Sync>, buf: &mut [u8], streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, close_stream_tx: &mpsc::Sender<u32>, new_stream_tx: &mut mpsc::Sender<Arc<Stream>>, remote_context: &mut Context, is_rtp: bool, ) -> Result<()>129 async fn incoming(
130 udp_rx: &Arc<dyn Conn + Send + Sync>,
131 buf: &mut [u8],
132 streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
133 close_stream_tx: &mpsc::Sender<u32>,
134 new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
135 remote_context: &mut Context,
136 is_rtp: bool,
137 ) -> Result<()> {
138 let n = udp_rx.recv(buf).await?;
139 if n == 0 {
140 return Err(Error::SessionEof);
141 }
142
143 let decrypted = if is_rtp {
144 remote_context.decrypt_rtp(&buf[0..n])?
145 } else {
146 remote_context.decrypt_rtcp(&buf[0..n])?
147 };
148
149 let mut buf = &decrypted[..];
150 let ssrcs = if is_rtp {
151 vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
152 } else {
153 let pkts = rtcp::packet::unmarshal(&mut buf)?;
154 destination_ssrc(&pkts)
155 };
156
157 for ssrc in ssrcs {
158 let (stream, is_new) =
159 Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
160 .await;
161 if is_new {
162 log::trace!(
163 "srtp session got new {} stream {}",
164 if is_rtp { "rtp" } else { "rtcp" },
165 ssrc
166 );
167 new_stream_tx.send(Arc::clone(&stream)).await?;
168 }
169
170 match stream.buffer.write(&decrypted).await {
171 Ok(_) => {}
172 Err(err) => {
173 // Silently drop data when the buffer is full.
174 if util::Error::ErrBufferFull != err {
175 return Err(err.into());
176 }
177 }
178 }
179 }
180
181 Ok(())
182 }
183
get_or_create_stream( streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, close_stream_tx: mpsc::Sender<u32>, is_rtp: bool, ssrc: u32, ) -> (Arc<Stream>, bool)184 async fn get_or_create_stream(
185 streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
186 close_stream_tx: mpsc::Sender<u32>,
187 is_rtp: bool,
188 ssrc: u32,
189 ) -> (Arc<Stream>, bool) {
190 let mut streams = streams_map.lock().await;
191
192 if let Some(stream) = streams.get(&ssrc) {
193 (Arc::clone(stream), false)
194 } else {
195 let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp));
196 streams.insert(ssrc, Arc::clone(&stream));
197 (stream, true)
198 }
199 }
200
201 /// open on the given SSRC to create a stream, it can be used
202 /// if you want a certain SSRC, but don't want to wait for Accept
open(&self, ssrc: u32) -> Arc<Stream>203 pub async fn open(&self, ssrc: u32) -> Arc<Stream> {
204 let (stream, _) = Session::get_or_create_stream(
205 &self.streams_map,
206 self.close_stream_tx.clone(),
207 self.is_rtp,
208 ssrc,
209 )
210 .await;
211
212 stream
213 }
214
215 /// accept returns a stream to handle RTCP for a single SSRC
accept(&self) -> Result<Arc<Stream>>216 pub async fn accept(&self) -> Result<Arc<Stream>> {
217 let mut new_stream_rx = self.new_stream_rx.lock().await;
218 let result = new_stream_rx.recv().await;
219 if let Some(stream) = result {
220 Ok(stream)
221 } else {
222 Err(Error::SessionSrtpAlreadyClosed)
223 }
224 }
225
close(&self) -> Result<()>226 pub async fn close(&self) -> Result<()> {
227 self.close_session_tx.send(()).await?;
228
229 Ok(())
230 }
231
write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize>232 pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize> {
233 if self.is_rtp != is_rtp {
234 return Err(Error::SessionRtpRtcpTypeMismatch);
235 }
236
237 let encrypted = {
238 let mut local_context = self.local_context.lock().await;
239
240 if is_rtp {
241 local_context.encrypt_rtp(buf)?
242 } else {
243 local_context.encrypt_rtcp(buf)?
244 }
245 };
246
247 Ok(self.udp_tx.send(&encrypted).await?)
248 }
249
write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize>250 pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
251 let raw = pkt.marshal()?;
252 self.write(&raw, true).await
253 }
254
write_rtcp( &self, pkt: &(dyn rtcp::packet::Packet + Send + Sync), ) -> Result<usize>255 pub async fn write_rtcp(
256 &self,
257 pkt: &(dyn rtcp::packet::Packet + Send + Sync),
258 ) -> Result<usize> {
259 let raw = pkt.marshal()?;
260 self.write(&raw, false).await
261 }
262 }
263
264 /// create a list of Destination SSRCs
265 /// that's a superset of all Destinations in the slice.
destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32>266 fn destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32> {
267 let mut ssrc_set = HashSet::new();
268 for p in pkts {
269 for ssrc in p.destination_ssrc() {
270 ssrc_set.insert(ssrc);
271 }
272 }
273 ssrc_set.into_iter().collect()
274 }
275