1 use super::*;
2 use crate::error::Result;
3 use crate::protection_profile::*;
4
5 use bytes::{Bytes, BytesMut};
6 use std::{collections::HashMap, sync::Arc};
7 use tokio::{
8 net::UdpSocket,
9 sync::{mpsc, Mutex},
10 };
11
build_session_srtp_pair() -> Result<(Session, Session)>12 async fn build_session_srtp_pair() -> Result<(Session, Session)> {
13 let ua = UdpSocket::bind("127.0.0.1:0").await?;
14 let ub = UdpSocket::bind("127.0.0.1:0").await?;
15
16 ua.connect(ub.local_addr()?).await?;
17 ub.connect(ua.local_addr()?).await?;
18
19 let ca = Config {
20 profile: ProtectionProfile::Aes128CmHmacSha1_80,
21 keys: SessionKeys {
22 local_master_key: vec![
23 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
24 0x41, 0x39,
25 ],
26 local_master_salt: vec![
27 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
28 ],
29 remote_master_key: vec![
30 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
31 0x41, 0x39,
32 ],
33 remote_master_salt: vec![
34 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
35 ],
36 },
37
38 local_rtp_options: None,
39 remote_rtp_options: None,
40
41 local_rtcp_options: None,
42 remote_rtcp_options: None,
43 };
44
45 let cb = Config {
46 profile: ProtectionProfile::Aes128CmHmacSha1_80,
47 keys: SessionKeys {
48 local_master_key: vec![
49 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
50 0x41, 0x39,
51 ],
52 local_master_salt: vec![
53 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
54 ],
55 remote_master_key: vec![
56 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
57 0x41, 0x39,
58 ],
59 remote_master_salt: vec![
60 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
61 ],
62 },
63
64 local_rtp_options: None,
65 remote_rtp_options: None,
66
67 local_rtcp_options: None,
68 remote_rtcp_options: None,
69 };
70
71 let sa = Session::new(Arc::new(ua), ca, true).await?;
72 let sb = Session::new(Arc::new(ub), cb, true).await?;
73
74 Ok((sa, sb))
75 }
76
/// SSRC used for the single-stream RTP packets in these tests.
const TEST_SSRC: u32 = 5_000;
/// Size in bytes of the fixed RTP header (no CSRCs, no header extension).
const RTP_HEADER_SIZE: usize = 12;
79
80 #[tokio::test]
test_session_srtp_accept() -> Result<()>81 async fn test_session_srtp_accept() -> Result<()> {
82 let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]);
83 let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len());
84 read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8);
85 let (sa, sb) = build_session_srtp_pair().await?;
86
87 let packet = rtp::packet::Packet {
88 header: rtp::header::Header {
89 ssrc: TEST_SSRC,
90 ..Default::default()
91 },
92 payload: test_payload.clone(),
93 };
94 sa.write_rtp(&packet).await?;
95
96 let read_stream = sb.accept().await?;
97 let ssrc = read_stream.get_ssrc();
98 assert_eq!(
99 ssrc, TEST_SSRC,
100 "SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})"
101 );
102
103 read_stream.read(&mut read_buffer).await?;
104
105 assert_eq!(
106 &test_payload[..],
107 &read_buffer[RTP_HEADER_SIZE..],
108 "Sent buffer does not match the one received exp({:?}) actual({:?})",
109 &test_payload[..],
110 &read_buffer[RTP_HEADER_SIZE..]
111 );
112
113 sa.close().await?;
114 sb.close().await?;
115
116 Ok(())
117 }
118
119 #[tokio::test]
test_session_srtp_listen() -> Result<()>120 async fn test_session_srtp_listen() -> Result<()> {
121 let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]);
122 let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len());
123 read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8);
124 let (sa, sb) = build_session_srtp_pair().await?;
125
126 let packet = rtp::packet::Packet {
127 header: rtp::header::Header {
128 ssrc: TEST_SSRC,
129 ..Default::default()
130 },
131 payload: test_payload.clone(),
132 };
133
134 let read_stream = sb.open(TEST_SSRC).await;
135
136 sa.write_rtp(&packet).await?;
137
138 read_stream.read(&mut read_buffer).await?;
139
140 assert_eq!(
141 &test_payload[..],
142 &read_buffer[RTP_HEADER_SIZE..],
143 "Sent buffer does not match the one received exp({:?}) actual({:?})",
144 &test_payload[..],
145 &read_buffer[RTP_HEADER_SIZE..]
146 );
147
148 sa.close().await?;
149 sb.close().await?;
150
151 Ok(())
152 }
153
154 #[tokio::test]
test_session_srtp_multi_ssrc() -> Result<()>155 async fn test_session_srtp_multi_ssrc() -> Result<()> {
156 let ssrcs = vec![5000, 5001, 5002];
157 let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]);
158 let mut read_buffer = BytesMut::with_capacity(RTP_HEADER_SIZE + test_payload.len());
159 read_buffer.resize(RTP_HEADER_SIZE + test_payload.len(), 0u8);
160 let (sa, sb) = build_session_srtp_pair().await?;
161
162 let mut read_streams = HashMap::new();
163 for ssrc in &ssrcs {
164 let read_stream = sb.open(*ssrc).await;
165 read_streams.insert(*ssrc, read_stream);
166 }
167
168 for ssrc in &ssrcs {
169 let packet = rtp::packet::Packet {
170 header: rtp::header::Header {
171 ssrc: *ssrc,
172 ..Default::default()
173 },
174 payload: test_payload.clone(),
175 };
176 sa.write_rtp(&packet).await?;
177
178 if let Some(read_stream) = read_streams.get_mut(ssrc) {
179 read_stream.read(&mut read_buffer).await?;
180
181 assert_eq!(
182 &test_payload[..],
183 &read_buffer[RTP_HEADER_SIZE..],
184 "Sent buffer does not match the one received exp({:?}) actual({:?})",
185 &test_payload[..],
186 &read_buffer[RTP_HEADER_SIZE..]
187 );
188 } else {
189 panic!("ssrc {} not found", *ssrc);
190 }
191 }
192
193 sa.close().await?;
194 sb.close().await?;
195
196 Ok(())
197 }
198
encrypt_srtp(context: &mut Context, pkt: &rtp::packet::Packet) -> Result<Bytes>199 fn encrypt_srtp(context: &mut Context, pkt: &rtp::packet::Packet) -> Result<Bytes> {
200 let decrypted = pkt.marshal()?;
201 let encrypted = context.encrypt_rtp(&decrypted)?;
202 Ok(encrypted)
203 }
204
payload_srtp( read_stream: &Arc<Stream>, header_size: usize, expected_payload: &[u8], ) -> Result<u16>205 async fn payload_srtp(
206 read_stream: &Arc<Stream>,
207 header_size: usize,
208 expected_payload: &[u8],
209 ) -> Result<u16> {
210 let mut read_buffer = BytesMut::with_capacity(header_size + expected_payload.len());
211 read_buffer.resize(header_size + expected_payload.len(), 0u8);
212
213 let (n, hdr) = read_stream.read_rtp(&mut read_buffer).await?;
214
215 assert_eq!(
216 expected_payload,
217 &read_buffer[header_size..n],
218 "Sent buffer does not match the one received exp({:?}) actual({:?})",
219 expected_payload,
220 &read_buffer[header_size..n]
221 );
222
223 Ok(hdr.sequence_number)
224 }
225
226 #[tokio::test]
test_session_srtp_replay_protection() -> Result<()>227 async fn test_session_srtp_replay_protection() -> Result<()> {
228 let test_payload = Bytes::from_static(&[0x00, 0x01, 0x03, 0x04]);
229
230 let (sa, sb) = build_session_srtp_pair().await?;
231
232 let read_stream = sb.open(TEST_SSRC).await;
233
234 // Generate test packets
235 let mut packets = vec![];
236 let mut expected_sequence_number = vec![];
237 {
238 let mut local_context = sa.local_context.lock().await;
239 let mut i = 0xFFF0u16;
240 while i != 0x10 {
241 expected_sequence_number.push(i);
242
243 let packet = rtp::packet::Packet {
244 header: rtp::header::Header {
245 ssrc: TEST_SSRC,
246 sequence_number: i,
247 ..Default::default()
248 },
249 payload: test_payload.clone(),
250 };
251
252 let encrypted = encrypt_srtp(&mut local_context, &packet)?;
253
254 packets.push(encrypted);
255
256 if i == 0xFFFF {
257 i = 0;
258 } else {
259 i += 1;
260 }
261 }
262 }
263
264 let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
265
266 let received_sequence_number = Arc::new(Mutex::new(vec![]));
267 let cloned_received_sequence_number = Arc::clone(&received_sequence_number);
268 let count = expected_sequence_number.len();
269
270 tokio::spawn(async move {
271 let mut i = 0;
272 while i < count {
273 let seq = payload_srtp(&read_stream, RTP_HEADER_SIZE, &test_payload)
274 .await
275 .unwrap();
276 let mut r = cloned_received_sequence_number.lock().await;
277 r.push(seq);
278
279 i += 1;
280 }
281
282 drop(done_tx);
283 });
284
285 // Write with replay attack
286 for packet in &packets {
287 sa.udp_tx.send(packet).await?;
288
289 // Immediately replay
290 sa.udp_tx.send(packet).await?;
291 }
292 for packet in &packets {
293 // Delayed replay
294 sa.udp_tx.send(packet).await?;
295 }
296
297 done_rx.recv().await;
298
299 sa.close().await?;
300 sb.close().await?;
301
302 {
303 let received_sequence_number = received_sequence_number.lock().await;
304 assert_eq!(&received_sequence_number[..], &expected_sequence_number[..]);
305 }
306
307 Ok(())
308 }
309