xref: /webrtc/srtp/src/context/mod.rs (revision ffe74184)
1 #[cfg(test)]
2 mod context_test;
3 #[cfg(test)]
4 mod srtcp_test;
5 #[cfg(test)]
6 mod srtp_test;
7 
8 use crate::error::Result;
9 use crate::{
10     cipher::cipher_aead_aes_gcm::*, cipher::cipher_aes_cm_hmac_sha1::*, cipher::*, error::Error,
11     option::*, protection_profile::*,
12 };
13 
14 use std::collections::HashMap;
15 use util::replay_detector::*;
16 
17 pub mod srtcp;
18 pub mod srtp;
19 
/// Width of the window on either side of the 16-bit sequence-number wrap
/// point inside which a packet is treated as reordered across a rollover
/// when the rollover counter is computed.
const MAX_ROC_DISORDER: u16 = 100;
21 
22 /// Encrypt/Decrypt state for a single SRTP SSRC
23 #[derive(Default)]
24 pub(crate) struct SrtpSsrcState {
25     ssrc: u32,
26     rollover_counter: u32,
27     rollover_has_processed: bool,
28     last_sequence_number: u16,
29     replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
30 }
31 
32 /// Encrypt/Decrypt state for a single SRTCP SSRC
33 #[derive(Default)]
34 pub(crate) struct SrtcpSsrcState {
35     srtcp_index: usize,
36     ssrc: u32,
37     replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
38 }
39 
40 impl SrtpSsrcState {
next_rollover_count(&self, sequence_number: u16) -> u3241     pub fn next_rollover_count(&self, sequence_number: u16) -> u32 {
42         let mut roc = self.rollover_counter;
43 
44         if !self.rollover_has_processed {
45         } else if sequence_number == 0 {
46             // We exactly hit the rollover count
47 
48             // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
49             // otherwise we already incremented for disorder
50             if self.last_sequence_number > MAX_ROC_DISORDER {
51                 roc += 1;
52             }
53         } else if self.last_sequence_number < MAX_ROC_DISORDER
54             && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
55         {
56             // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
57             // So we fell behind, drop to account for jitter
58             roc -= 1;
59         } else if sequence_number < MAX_ROC_DISORDER
60             && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
61         {
62             // our current is within a MAX_ROCDISORDER of 0
63             // and our last sequence number was a high sequence number, increment to account for jitter
64             roc += 1;
65         }
66 
67         roc
68     }
69 
70     /// https://tools.ietf.org/html/rfc3550#appendix-A.1
update_rollover_count(&mut self, sequence_number: u16)71     pub fn update_rollover_count(&mut self, sequence_number: u16) {
72         if !self.rollover_has_processed {
73             self.rollover_has_processed = true;
74         } else if sequence_number == 0 {
75             // We exactly hit the rollover count
76 
77             // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
78             // otherwise we already incremented for disorder
79             if self.last_sequence_number > MAX_ROC_DISORDER {
80                 self.rollover_counter += 1;
81             }
82         } else if self.last_sequence_number < MAX_ROC_DISORDER
83             && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
84         {
85             // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
86             // So we fell behind, drop to account for jitter
87             self.rollover_counter -= 1;
88         } else if sequence_number < MAX_ROC_DISORDER
89             && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
90         {
91             // our current is within a MAX_ROCDISORDER of 0
92             // and our last sequence number was a high sequence number, increment to account for jitter
93             self.rollover_counter += 1;
94         }
95         self.last_sequence_number = sequence_number;
96     }
97 }
98 
99 /// Context represents a SRTP cryptographic context
100 /// Context can only be used for one-way operations
101 /// it must either used ONLY for encryption or ONLY for decryption
102 pub struct Context {
103     cipher: Box<dyn Cipher + Send>,
104 
105     srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
106     srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
107 
108     new_srtp_replay_detector: ContextOption,
109     new_srtcp_replay_detector: ContextOption,
110 }
111 
112 impl Context {
113     /// CreateContext creates a new SRTP Context
new( master_key: &[u8], master_salt: &[u8], profile: ProtectionProfile, srtp_ctx_opt: Option<ContextOption>, srtcp_ctx_opt: Option<ContextOption>, ) -> Result<Context>114     pub fn new(
115         master_key: &[u8],
116         master_salt: &[u8],
117         profile: ProtectionProfile,
118         srtp_ctx_opt: Option<ContextOption>,
119         srtcp_ctx_opt: Option<ContextOption>,
120     ) -> Result<Context> {
121         let key_len = profile.key_len();
122         let salt_len = profile.salt_len();
123 
124         if master_key.len() != key_len {
125             return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
126         } else if master_salt.len() != salt_len {
127             return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
128         }
129 
130         let cipher: Box<dyn Cipher + Send> = match profile {
131             ProtectionProfile::Aes128CmHmacSha1_80 => {
132                 Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?)
133             }
134 
135             ProtectionProfile::AeadAes128Gcm => {
136                 Box::new(CipherAeadAesGcm::new(master_key, master_salt)?)
137             }
138         };
139 
140         let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
141             ctx_opt
142         } else {
143             srtp_no_replay_protection()
144         };
145 
146         let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
147             ctx_opt
148         } else {
149             srtcp_no_replay_protection()
150         };
151 
152         Ok(Context {
153             cipher,
154             srtp_ssrc_states: HashMap::new(),
155             srtcp_ssrc_states: HashMap::new(),
156             new_srtp_replay_detector: srtp_ctx_opt,
157             new_srtcp_replay_detector: srtcp_ctx_opt,
158         })
159     }
160 
get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState>161     fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState> {
162         let s = SrtpSsrcState {
163             ssrc,
164             replay_detector: Some((self.new_srtp_replay_detector)()),
165             ..Default::default()
166         };
167 
168         self.srtp_ssrc_states.entry(ssrc).or_insert(s);
169         self.srtp_ssrc_states.get_mut(&ssrc)
170     }
171 
get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState>172     fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState> {
173         let s = SrtcpSsrcState {
174             ssrc,
175             replay_detector: Some((self.new_srtcp_replay_detector)()),
176             ..Default::default()
177         };
178         self.srtcp_ssrc_states.entry(ssrc).or_insert(s);
179         self.srtcp_ssrc_states.get_mut(&ssrc)
180     }
181 
182     /// roc returns SRTP rollover counter value of specified SSRC.
get_roc(&self, ssrc: u32) -> Option<u32>183     fn get_roc(&self, ssrc: u32) -> Option<u32> {
184         self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter)
185     }
186 
187     /// set_roc sets SRTP rollover counter value of specified SSRC.
set_roc(&mut self, ssrc: u32, roc: u32)188     fn set_roc(&mut self, ssrc: u32, roc: u32) {
189         if let Some(s) = self.get_srtp_ssrc_state(ssrc) {
190             s.rollover_counter = roc;
191         }
192     }
193 
194     /// index returns SRTCP index value of specified SSRC.
get_index(&self, ssrc: u32) -> Option<usize>195     fn get_index(&self, ssrc: u32) -> Option<usize> {
196         self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
197     }
198 
199     /// set_index sets SRTCP index value of specified SSRC.
set_index(&mut self, ssrc: u32, index: usize)200     fn set_index(&mut self, ssrc: u32, index: usize) {
201         if let Some(s) = self.get_srtcp_ssrc_state(ssrc) {
202             s.srtcp_index = index;
203         }
204     }
205 }
206