1 #[cfg(test)] 2 mod context_test; 3 #[cfg(test)] 4 mod srtcp_test; 5 #[cfg(test)] 6 mod srtp_test; 7 8 use crate::error::Result; 9 use crate::{ 10 cipher::cipher_aead_aes_gcm::*, cipher::cipher_aes_cm_hmac_sha1::*, cipher::*, error::Error, 11 option::*, protection_profile::*, 12 }; 13 14 use std::collections::HashMap; 15 use util::replay_detector::*; 16 17 pub mod srtcp; 18 pub mod srtp; 19 20 const MAX_ROC_DISORDER: u16 = 100; 21 22 /// Encrypt/Decrypt state for a single SRTP SSRC 23 #[derive(Default)] 24 pub(crate) struct SrtpSsrcState { 25 ssrc: u32, 26 rollover_counter: u32, 27 rollover_has_processed: bool, 28 last_sequence_number: u16, 29 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>, 30 } 31 32 /// Encrypt/Decrypt state for a single SRTCP SSRC 33 #[derive(Default)] 34 pub(crate) struct SrtcpSsrcState { 35 srtcp_index: usize, 36 ssrc: u32, 37 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>, 38 } 39 40 impl SrtpSsrcState { next_rollover_count(&self, sequence_number: u16) -> u3241 pub fn next_rollover_count(&self, sequence_number: u16) -> u32 { 42 let mut roc = self.rollover_counter; 43 44 if !self.rollover_has_processed { 45 } else if sequence_number == 0 { 46 // We exactly hit the rollover count 47 48 // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER 49 // otherwise we already incremented for disorder 50 if self.last_sequence_number > MAX_ROC_DISORDER { 51 roc += 1; 52 } 53 } else if self.last_sequence_number < MAX_ROC_DISORDER 54 && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) 55 { 56 // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max 57 // So we fell behind, drop to account for jitter 58 roc -= 1; 59 } else if sequence_number < MAX_ROC_DISORDER 60 && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) 61 { 62 // our current is within a MAX_ROCDISORDER of 0 63 // and our last sequence number was a high sequence number, increment to account for jitter 64 roc += 1; 65 } 66 67 roc 68 } 69 70 /// https://tools.ietf.org/html/rfc3550#appendix-A.1 update_rollover_count(&mut self, sequence_number: u16)71 pub fn update_rollover_count(&mut self, sequence_number: u16) { 72 if !self.rollover_has_processed { 73 self.rollover_has_processed = true; 74 } else if sequence_number == 0 { 75 // We exactly hit the rollover count 76 77 // Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER 78 // otherwise we already incremented for disorder 79 if self.last_sequence_number > MAX_ROC_DISORDER { 80 self.rollover_counter += 1; 81 } 82 } else if self.last_sequence_number < MAX_ROC_DISORDER 83 && sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) 84 { 85 // Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max 86 // So we fell behind, drop to account for jitter 87 self.rollover_counter -= 1; 88 } else if sequence_number < MAX_ROC_DISORDER 89 && self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER) 90 { 91 // our current is within a MAX_ROCDISORDER of 0 92 // and our last sequence number was a high sequence number, increment to account for jitter 93 self.rollover_counter += 1; 94 } 95 self.last_sequence_number = sequence_number; 96 } 97 } 98 99 /// Context represents a SRTP cryptographic context 100 /// Context can only be used for one-way operations 101 /// it must either used ONLY for encryption or ONLY for decryption 102 pub struct Context { 103 cipher: Box<dyn Cipher + Send>, 104 105 srtp_ssrc_states: HashMap<u32, SrtpSsrcState>, 106 srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>, 107 108 new_srtp_replay_detector: ContextOption, 109 new_srtcp_replay_detector: ContextOption, 110 } 111 112 impl Context { 113 /// CreateContext creates a new SRTP Context new( master_key: &[u8], master_salt: &[u8], profile: ProtectionProfile, srtp_ctx_opt: Option<ContextOption>, srtcp_ctx_opt: Option<ContextOption>, ) -> Result<Context>114 pub fn new( 115 master_key: &[u8], 116 master_salt: &[u8], 117 profile: ProtectionProfile, 118 srtp_ctx_opt: Option<ContextOption>, 119 srtcp_ctx_opt: Option<ContextOption>, 120 ) -> Result<Context> { 121 let key_len = profile.key_len(); 122 let salt_len = profile.salt_len(); 123 124 if master_key.len() != key_len { 125 return Err(Error::SrtpMasterKeyLength(key_len, master_key.len())); 126 } else if master_salt.len() != salt_len { 127 return Err(Error::SrtpSaltLength(salt_len, master_salt.len())); 128 } 129 130 let cipher: Box<dyn Cipher + Send> = match profile { 131 ProtectionProfile::Aes128CmHmacSha1_80 => { 132 Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?) 133 } 134 135 ProtectionProfile::AeadAes128Gcm => { 136 Box::new(CipherAeadAesGcm::new(master_key, master_salt)?) 137 } 138 }; 139 140 let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt { 141 ctx_opt 142 } else { 143 srtp_no_replay_protection() 144 }; 145 146 let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt { 147 ctx_opt 148 } else { 149 srtcp_no_replay_protection() 150 }; 151 152 Ok(Context { 153 cipher, 154 srtp_ssrc_states: HashMap::new(), 155 srtcp_ssrc_states: HashMap::new(), 156 new_srtp_replay_detector: srtp_ctx_opt, 157 new_srtcp_replay_detector: srtcp_ctx_opt, 158 }) 159 } 160 get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState>161 fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtpSsrcState> { 162 let s = SrtpSsrcState { 163 ssrc, 164 replay_detector: Some((self.new_srtp_replay_detector)()), 165 ..Default::default() 166 }; 167 168 self.srtp_ssrc_states.entry(ssrc).or_insert(s); 169 self.srtp_ssrc_states.get_mut(&ssrc) 170 } 171 get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState>172 fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState> { 173 let s = SrtcpSsrcState { 174 ssrc, 175 replay_detector: Some((self.new_srtcp_replay_detector)()), 176 ..Default::default() 177 }; 178 self.srtcp_ssrc_states.entry(ssrc).or_insert(s); 179 self.srtcp_ssrc_states.get_mut(&ssrc) 180 } 181 182 /// roc returns SRTP rollover counter value of specified SSRC. get_roc(&self, ssrc: u32) -> Option<u32>183 fn get_roc(&self, ssrc: u32) -> Option<u32> { 184 self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter) 185 } 186 187 /// set_roc sets SRTP rollover counter value of specified SSRC. set_roc(&mut self, ssrc: u32, roc: u32)188 fn set_roc(&mut self, ssrc: u32, roc: u32) { 189 if let Some(s) = self.get_srtp_ssrc_state(ssrc) { 190 s.rollover_counter = roc; 191 } 192 } 193 194 /// index returns SRTCP index value of specified SSRC. get_index(&self, ssrc: u32) -> Option<usize>195 fn get_index(&self, ssrc: u32) -> Option<usize> { 196 self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index) 197 } 198 199 /// set_index sets SRTCP index value of specified SSRC. set_index(&mut self, ssrc: u32, index: usize)200 fn set_index(&mut self, ssrc: u32, index: usize) { 201 if let Some(s) = self.get_srtcp_ssrc_state(ssrc) { 202 s.srtcp_index = index; 203 } 204 } 205 } 206