// Source: webrtc/srtp/src/context/context_test.rs (revision 5d8fe953)
1 use super::*;
2 use crate::key_derivation::*;
3 
4 use bytes::Bytes;
5 use lazy_static::lazy_static;
6 
7 const CIPHER_CONTEXT_ALGO: ProtectionProfile = ProtectionProfile::Aes128CmHmacSha1_80;
8 const DEFAULT_SSRC: u32 = 0;
9 
/// A fresh context must report no rollover counter (ROC) for an SSRC it has
/// never seen. After `set_roc`, `get_roc` must return exactly the stored value.
#[test]
fn test_context_roc() -> Result<()> {
    let key_len = CIPHER_CONTEXT_ALGO.key_len();
    let salt_len = CIPHER_CONTEXT_ALGO.salt_len();

    let mut c = Context::new(
        &vec![0; key_len],
        &vec![0; salt_len],
        CIPHER_CONTEXT_ALGO,
        None,
        None,
    )?;

    assert!(
        c.get_roc(123).is_none(),
        "ROC must return None for unused SSRC"
    );

    c.set_roc(123, 100);
    assert_eq!(
        c.get_roc(123),
        Some(100),
        "ROC must return the value set for a used SSRC"
    );

    Ok(())
}

/// A fresh context must report no SRTCP index for an SSRC it has never seen.
/// After `set_index`, `get_index` must return exactly the stored value.
#[test]
fn test_context_index() -> Result<()> {
    let key_len = CIPHER_CONTEXT_ALGO.key_len();
    let salt_len = CIPHER_CONTEXT_ALGO.salt_len();

    let mut c = Context::new(
        &vec![0; key_len],
        &vec![0; salt_len],
        CIPHER_CONTEXT_ALGO,
        None,
        None,
    )?;

    assert!(
        c.get_index(123).is_none(),
        "Index must return None for unused SSRC"
    );

    c.set_index(123, 100);
    assert_eq!(
        c.get_index(123),
        Some(100),
        "Index must return the value set for a used SSRC"
    );

    Ok(())
}

/// `Context::new` must reject an empty master key or an empty master salt.
/// It must accept a key and salt whose lengths match the profile's
/// `key_len()` / `salt_len()`.
#[test]
fn test_key_len() -> Result<()> {
    let key_len = CIPHER_CONTEXT_ALGO.key_len();
    let salt_len = CIPHER_CONTEXT_ALGO.salt_len();
    let valid_key = vec![0; key_len];
    let valid_salt = vec![0; salt_len];

    let result = Context::new(&[], &valid_salt, CIPHER_CONTEXT_ALGO, None, None);
    assert!(result.is_err(), "CreateContext accepted a 0 length key");

    let result = Context::new(&valid_key, &[], CIPHER_CONTEXT_ALGO, None, None);
    assert!(result.is_err(), "CreateContext accepted a 0 length salt");

    let result = Context::new(&valid_key, &valid_salt, CIPHER_CONTEXT_ALGO, None, None);
    assert!(
        result.is_ok(),
        "CreateContext failed with a valid length key and salt"
    );

    Ok(())
}

/// `generate_counter` must produce a known 16-byte counter block.
/// Inputs: the SRTP session salt derived from a fixed master key/salt,
/// sequence number 32846, SSRC 4160032510, and a default rollover counter.
#[test]
fn test_valid_packet_counter() -> Result<()> {
    let master_key = vec![
        0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28,
        0x89,
    ];
    let master_salt = vec![
        0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c,
    ];

    // Derive the session salt, requesting as many output bytes as the master salt has.
    let srtp_session_salt = aes_cm_key_derivation(
        LABEL_SRTP_SALT,
        &master_key,
        &master_salt,
        0,
        master_salt.len(),
    )?;

    // Only the state's SSRC and its default rollover counter feed the counter.
    let s = SrtpSsrcState {
        ssrc: 4160032510,
        ..Default::default()
    };
    let expected_counter = vec![
        0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00,
        0x00,
    ];
    let counter = generate_counter(32846, s.rollover_counter, s.ssrc, &srtp_session_salt)?;
    assert_eq!(
        counter, expected_counter,
        "Counter {counter:?} does not match expected {expected_counter:?}",
    );

    Ok(())
}

/// Exercises the rollover-counter (ROC) estimation in `SrtpSsrcState`.
/// Covered cases:
/// - wrapping forward past sequence number 0;
/// - stepping back across the wrap for an out-of-order packet;
/// - staying in the current cycle for non-wrapping packets;
/// - incrementing again on the next wrap.
#[test]
fn test_rollover_count() -> Result<()> {
    let mut s = SrtpSsrcState {
        ssrc: DEFAULT_SSRC,
        ..Default::default()
    };

    // Set initial seqnum
    let roc = s.next_rollover_count(65530);
    assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
    s.update_rollover_count(65530);

    // Only probe these sequence numbers. Without a matching
    // update_rollover_count they must not change the state, which the
    // following assertion relies on.
    for seq in [0, 0x4000, 0x8000, 0xFFFF, 0] {
        s.next_rollover_count(seq);
    }

    // We rolled over to 0
    let roc = s.next_rollover_count(0);
    assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0");
    s.update_rollover_count(0);

    let roc = s.next_rollover_count(65530);
    assert_eq!(
        roc, 0,
        "rolloverCounter was not updated when it rolled back, failed to handle out of order"
    );
    s.update_rollover_count(65530);

    let roc = s.next_rollover_count(5);
    assert_eq!(
        roc, 1,
        "rolloverCounter was not updated when it rolled over initial, to handle out of order"
    );
    s.update_rollover_count(5);

    // Accept a few consecutive packets in the new cycle.
    for seq in [6, 7] {
        s.next_rollover_count(seq);
        s.update_rollover_count(seq);
    }

    // Packets that stay within the current cycle must not change the ROC.
    for seq in [8, 0x4000, 0x8000, 0xFFFF] {
        let roc = s.next_rollover_count(seq);
        assert_eq!(
            roc, 1,
            "rolloverCounter was improperly updated for non-significant packets (seq {seq})"
        );
        s.update_rollover_count(seq);
    }

    let roc = s.next_rollover_count(0);
    assert_eq!(
        roc, 2,
        "rolloverCounter must be incremented after wrapping, got {roc}"
    );

    Ok(())
}

207 lazy_static! {
208     static ref MASTER_KEY: Bytes = Bytes::from_static(&[
209         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
210         0x0f,
211     ]);
212     static ref MASTER_SALT: Bytes = Bytes::from_static(&[
213         0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab,
214     ]);
215     static ref DECRYPTED_RTP_PACKET: Bytes = Bytes::from_static(&[
216         0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab,
217         0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab,
218     ]);
219     static ref ENCRYPTED_RTP_PACKET: Bytes = Bytes::from_static(&[
220         0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e,
221         0xde, 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97,
222         0x68, 0x26, 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, 0x27, 0xe8, 0xa3, 0x92,
223     ]);
224     static ref DECRYPTED_RTCP_PACKET: Bytes = Bytes::from_static(&[
225         0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab,
226         0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab,
227     ]);
228     static ref ENCRYPTED_RTCP_PACKET: Bytes = Bytes::from_static(&[
229         0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a,
230         0x55, 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3,
231         0xb4, 0x46, 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01,
232     ]);
233 }
234 
235 #[test]
test_encrypt_rtp()236 fn test_encrypt_rtp() {
237     let mut ctx = Context::new(
238         &MASTER_KEY,
239         &MASTER_SALT,
240         ProtectionProfile::AeadAes128Gcm,
241         None,
242         None,
243     )
244     .expect("Error creating srtp context");
245 
246     let gotten_encrypted_rtp_packet = ctx
247         .encrypt_rtp(&DECRYPTED_RTP_PACKET)
248         .expect("Error encrypting rtp payload");
249 
250     assert_eq!(gotten_encrypted_rtp_packet, *ENCRYPTED_RTP_PACKET)
251 }
252 
253 #[test]
test_decrypt_rtp()254 fn test_decrypt_rtp() {
255     let mut ctx = Context::new(
256         &MASTER_KEY,
257         &MASTER_SALT,
258         ProtectionProfile::AeadAes128Gcm,
259         None,
260         None,
261     )
262     .expect("Error creating srtp context");
263 
264     let gotten_decrypted_rtp_packet = ctx
265         .decrypt_rtp(&ENCRYPTED_RTP_PACKET)
266         .expect("Error decrypting rtp payload");
267 
268     assert_eq!(gotten_decrypted_rtp_packet, *DECRYPTED_RTP_PACKET)
269 }
270 
271 #[test]
test_encrypt_rtcp()272 fn test_encrypt_rtcp() {
273     let mut ctx = Context::new(
274         &MASTER_KEY,
275         &MASTER_SALT,
276         ProtectionProfile::AeadAes128Gcm,
277         None,
278         None,
279     )
280     .expect("Error creating srtp context");
281 
282     let gotten_encrypted_rtcp_packet = ctx
283         .encrypt_rtcp(&DECRYPTED_RTCP_PACKET)
284         .expect("Error encrypting rtcp payload");
285 
286     assert_eq!(gotten_encrypted_rtcp_packet, *ENCRYPTED_RTCP_PACKET)
287 }
288 
289 #[test]
test_decrypt_rtcp()290 fn test_decrypt_rtcp() {
291     let mut ctx = Context::new(
292         &MASTER_KEY,
293         &MASTER_SALT,
294         ProtectionProfile::AeadAes128Gcm,
295         None,
296         None,
297     )
298     .expect("Error creating srtp context");
299 
300     let gotten_decrypted_rtcp_packet = ctx
301         .decrypt_rtcp(&ENCRYPTED_RTCP_PACKET)
302         .expect("Error decrypting rtcp payload");
303 
304     assert_eq!(gotten_decrypted_rtcp_packet, *DECRYPTED_RTCP_PACKET)
305 }
306