1 use super::*;
2 use crate::error::Result;
3 use crate::protection_profile::*;
4 
5 use rtcp::payload_feedbacks::*;
6 use util::conn::conn_pipe::*;
7 
8 use bytes::{Bytes, BytesMut};
9 use std::sync::Arc;
10 use tokio::sync::{mpsc, Mutex};
11 
build_session_srtcp_pair() -> Result<(Session, Session)>12 async fn build_session_srtcp_pair() -> Result<(Session, Session)> {
13     let (ua, ub) = pipe();
14 
15     let ca = Config {
16         profile: ProtectionProfile::Aes128CmHmacSha1_80,
17         keys: SessionKeys {
18             local_master_key: vec![
19                 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
20                 0x41, 0x39,
21             ],
22             local_master_salt: vec![
23                 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
24             ],
25             remote_master_key: vec![
26                 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
27                 0x41, 0x39,
28             ],
29             remote_master_salt: vec![
30                 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
31             ],
32         },
33 
34         local_rtp_options: None,
35         remote_rtp_options: None,
36 
37         local_rtcp_options: None,
38         remote_rtcp_options: None,
39     };
40 
41     let cb = Config {
42         profile: ProtectionProfile::Aes128CmHmacSha1_80,
43         keys: SessionKeys {
44             local_master_key: vec![
45                 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
46                 0x41, 0x39,
47             ],
48             local_master_salt: vec![
49                 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
50             ],
51             remote_master_key: vec![
52                 0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE,
53                 0x41, 0x39,
54             ],
55             remote_master_salt: vec![
56                 0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6,
57             ],
58         },
59 
60         local_rtp_options: None,
61         remote_rtp_options: None,
62 
63         local_rtcp_options: None,
64         remote_rtcp_options: None,
65     };
66 
67     let sa = Session::new(Arc::new(ua), ca, false).await?;
68     let sb = Session::new(Arc::new(ub), cb, false).await?;
69 
70     Ok((sa, sb))
71 }
72 
/// Media SSRC carried by every PLI packet these tests send.
const TEST_SSRC: u32 = 5_000;
74 
75 #[tokio::test]
test_session_srtcp_accept() -> Result<()>76 async fn test_session_srtcp_accept() -> Result<()> {
77     let (sa, sb) = build_session_srtcp_pair().await?;
78 
79     let rtcp_packet = picture_loss_indication::PictureLossIndication {
80         media_ssrc: TEST_SSRC,
81         ..Default::default()
82     };
83 
84     let test_payload = rtcp_packet.marshal()?;
85     sa.write_rtcp(&rtcp_packet).await?;
86 
87     let read_stream = sb.accept().await?;
88     let ssrc = read_stream.get_ssrc();
89     assert_eq!(
90         ssrc, TEST_SSRC,
91         "SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})"
92     );
93 
94     let mut read_buffer = BytesMut::with_capacity(test_payload.len());
95     read_buffer.resize(test_payload.len(), 0u8);
96     read_stream.read(&mut read_buffer).await?;
97 
98     assert_eq!(
99         &test_payload[..],
100         &read_buffer[..],
101         "Sent buffer does not match the one received exp({:?}) actual({:?})",
102         &test_payload[..],
103         &read_buffer[..]
104     );
105 
106     sa.close().await?;
107     sb.close().await?;
108 
109     Ok(())
110 }
111 
112 #[tokio::test]
test_session_srtcp_listen() -> Result<()>113 async fn test_session_srtcp_listen() -> Result<()> {
114     let (sa, sb) = build_session_srtcp_pair().await?;
115 
116     let rtcp_packet = picture_loss_indication::PictureLossIndication {
117         media_ssrc: TEST_SSRC,
118         ..Default::default()
119     };
120 
121     let test_payload = rtcp_packet.marshal()?;
122     let read_stream = sb.open(TEST_SSRC).await;
123 
124     sa.write_rtcp(&rtcp_packet).await?;
125 
126     let mut read_buffer = BytesMut::with_capacity(test_payload.len());
127     read_buffer.resize(test_payload.len(), 0u8);
128     read_stream.read(&mut read_buffer).await?;
129 
130     assert_eq!(
131         &test_payload[..],
132         &read_buffer[..],
133         "Sent buffer does not match the one received exp({:?}) actual({:?})",
134         &test_payload[..],
135         &read_buffer[..]
136     );
137 
138     sa.close().await?;
139     sb.close().await?;
140 
141     Ok(())
142 }
143 
encrypt_srtcp( context: &mut Context, pkt: &(dyn rtcp::packet::Packet + Send + Sync), ) -> Result<Bytes>144 fn encrypt_srtcp(
145     context: &mut Context,
146     pkt: &(dyn rtcp::packet::Packet + Send + Sync),
147 ) -> Result<Bytes> {
148     let decrypted = pkt.marshal()?;
149     let encrypted = context.encrypt_rtcp(&decrypted)?;
150     Ok(encrypted)
151 }
152 
/// Size of the two 32-bit SSRC fields (sender + media) of a PLI packet.
/// NOTE(review): the 4-byte RTCP common header is not counted here. `get_sender_ssrc`
/// adds the auth-tag length on top when sizing its read buffer; confirm that sum
/// always covers the full decrypted packet.
const PLI_PACKET_SIZE: usize = 2 * std::mem::size_of::<u32>();
154 
get_sender_ssrc(read_stream: &Arc<Stream>) -> Result<u32>155 async fn get_sender_ssrc(read_stream: &Arc<Stream>) -> Result<u32> {
156     let auth_tag_size = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len();
157 
158     let mut read_buffer = BytesMut::with_capacity(PLI_PACKET_SIZE + auth_tag_size);
159     read_buffer.resize(PLI_PACKET_SIZE + auth_tag_size, 0u8);
160 
161     let (n, _) = read_stream.read_rtcp(&mut read_buffer).await?;
162     let mut reader = &read_buffer[0..n];
163     let pli = picture_loss_indication::PictureLossIndication::unmarshal(&mut reader)?;
164 
165     Ok(pli.sender_ssrc)
166 }
167 
168 #[tokio::test]
test_session_srtcp_replay_protection() -> Result<()>169 async fn test_session_srtcp_replay_protection() -> Result<()> {
170     let (sa, sb) = build_session_srtcp_pair().await?;
171 
172     let read_stream = sb.open(TEST_SSRC).await;
173 
174     // Generate test packets
175     let mut packets = vec![];
176     let mut expected_ssrc = vec![];
177     {
178         let mut local_context = sa.local_context.lock().await;
179         for i in 0..0x10u32 {
180             expected_ssrc.push(i);
181 
182             let packet = picture_loss_indication::PictureLossIndication {
183                 media_ssrc: TEST_SSRC,
184                 sender_ssrc: i,
185             };
186 
187             let encrypted = encrypt_srtcp(&mut local_context, &packet)?;
188 
189             packets.push(encrypted);
190         }
191     }
192 
193     let (done_tx, mut done_rx) = mpsc::channel::<()>(1);
194 
195     let received_ssrc = Arc::new(Mutex::new(vec![]));
196     let cloned_received_ssrc = Arc::clone(&received_ssrc);
197     let count = expected_ssrc.len();
198 
199     tokio::spawn(async move {
200         let mut i = 0;
201         while i < count {
202             match get_sender_ssrc(&read_stream).await {
203                 Ok(ssrc) => {
204                     let mut r = cloned_received_ssrc.lock().await;
205                     r.push(ssrc);
206 
207                     i += 1;
208                 }
209                 Err(_) => break,
210             }
211         }
212 
213         drop(done_tx);
214     });
215 
216     // Write with replay attack
217     for packet in &packets {
218         sa.udp_tx.send(packet).await?;
219 
220         // Immediately replay
221         sa.udp_tx.send(packet).await?;
222     }
223     for packet in &packets {
224         // Delayed replay
225         sa.udp_tx.send(packet).await?;
226     }
227 
228     done_rx.recv().await;
229 
230     sa.close().await?;
231     sb.close().await?;
232 
233     {
234         let received_ssrc = received_ssrc.lock().await;
235         assert_eq!(&expected_ssrc[..], &received_ssrc[..]);
236     }
237 
238     Ok(())
239 }
240