1 use super::cipher_suite::*; 2 use super::conn::*; 3 use super::curve::named_curve::*; 4 use super::extension::extension_use_srtp::SrtpProtectionProfile; 5 use super::handshake::handshake_random::*; 6 use super::prf::*; 7 use crate::error::*; 8 9 use async_trait::async_trait; 10 use serde::{Deserialize, Serialize}; 11 use std::io::{BufWriter, Cursor}; 12 use std::marker::{Send, Sync}; 13 use std::sync::atomic::{AtomicU16, Ordering}; 14 use std::sync::Arc; 15 use tokio::sync::Mutex; 16 use util::KeyingMaterialExporter; 17 use util::KeyingMaterialExporterError; 18 19 // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler 20 pub struct State { 21 pub(crate) local_epoch: Arc<AtomicU16>, 22 pub(crate) remote_epoch: Arc<AtomicU16>, 23 pub(crate) local_sequence_number: Arc<Mutex<Vec<u64>>>, // uint48 24 pub(crate) local_random: HandshakeRandom, 25 pub(crate) remote_random: HandshakeRandom, 26 pub(crate) master_secret: Vec<u8>, 27 pub(crate) cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, // nil if a cipher_suite hasn't been chosen 28 29 pub(crate) srtp_protection_profile: SrtpProtectionProfile, // Negotiated srtp_protection_profile 30 pub peer_certificates: Vec<Vec<u8>>, 31 pub identity_hint: Vec<u8>, 32 33 pub(crate) is_client: bool, 34 35 pub(crate) pre_master_secret: Vec<u8>, 36 pub(crate) extended_master_secret: bool, 37 38 pub(crate) named_curve: NamedCurve, 39 pub(crate) local_keypair: Option<NamedCurveKeypair>, 40 pub(crate) cookie: Vec<u8>, 41 pub(crate) handshake_send_sequence: isize, 42 pub(crate) handshake_recv_sequence: isize, 43 pub(crate) server_name: String, 44 pub(crate) remote_requested_certificate: bool, // Did we get a CertificateRequest 45 pub(crate) local_certificates_verify: Vec<u8>, // cache CertificateVerify 46 pub(crate) local_verify_data: Vec<u8>, // cached VerifyData 47 pub(crate) local_key_signature: Vec<u8>, // cached keySignature 48 pub(crate) peer_certificates_verified: bool, 49 //pub(crate) replay_detector: Vec<Box<dyn ReplayDetector + Send + Sync>>, 50 } 51 52 #[derive(Serialize, Deserialize, PartialEq, Debug)] 53 struct SerializedState { 54 local_epoch: u16, 55 remote_epoch: u16, 56 local_random: [u8; HANDSHAKE_RANDOM_LENGTH], 57 remote_random: [u8; HANDSHAKE_RANDOM_LENGTH], 58 cipher_suite_id: u16, 59 master_secret: Vec<u8>, 60 sequence_number: u64, 61 srtp_protection_profile: u16, 62 peer_certificates: Vec<Vec<u8>>, 63 identity_hint: Vec<u8>, 64 is_client: bool, 65 } 66 67 impl Default for State { default() -> Self68 fn default() -> Self { 69 State { 70 local_epoch: Arc::new(AtomicU16::new(0)), 71 remote_epoch: Arc::new(AtomicU16::new(0)), 72 local_sequence_number: Arc::new(Mutex::new(vec![])), 73 local_random: HandshakeRandom::default(), 74 remote_random: HandshakeRandom::default(), 75 master_secret: vec![], 76 cipher_suite: Arc::new(Mutex::new(None)), // nil if a cipher_suite hasn't been chosen 77 78 srtp_protection_profile: SrtpProtectionProfile::Unsupported, // Negotiated srtp_protection_profile 79 peer_certificates: vec![], 80 identity_hint: vec![], 81 82 is_client: false, 83 84 pre_master_secret: vec![], 85 extended_master_secret: false, 86 87 named_curve: NamedCurve::Unsupported, 88 local_keypair: None, 89 cookie: vec![], 90 handshake_send_sequence: 0, 91 handshake_recv_sequence: 0, 92 server_name: "".to_string(), 93 remote_requested_certificate: false, // Did we get a CertificateRequest 94 local_certificates_verify: vec![], // cache CertificateVerify 95 local_verify_data: vec![], // cached VerifyData 96 local_key_signature: vec![], // cached keySignature 97 peer_certificates_verified: false, 98 //replay_detector: vec![], 99 } 100 } 101 } 102 103 impl State { clone(&self) -> Self104 pub(crate) async fn clone(&self) -> Self { 105 let mut state = State::default(); 106 107 if let Ok(serialized) = self.serialize().await { 108 let _ = state.deserialize(&serialized).await; 109 } 110 111 state 112 } 113 serialize(&self) -> Result<SerializedState>114 async fn serialize(&self) -> Result<SerializedState> { 115 let mut local_rand = vec![]; 116 { 117 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_rand.as_mut()); 118 self.local_random.marshal(&mut writer)?; 119 } 120 let mut remote_rand = vec![]; 121 { 122 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_rand.as_mut()); 123 self.remote_random.marshal(&mut writer)?; 124 } 125 126 let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH]; 127 let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH]; 128 129 local_random.copy_from_slice(&local_rand); 130 remote_random.copy_from_slice(&remote_rand); 131 132 let local_epoch = self.local_epoch.load(Ordering::SeqCst); 133 let remote_epoch = self.remote_epoch.load(Ordering::SeqCst); 134 let sequence_number = { 135 let lsn = self.local_sequence_number.lock().await; 136 lsn[local_epoch as usize] 137 }; 138 let cipher_suite_id = { 139 let cipher_suite = self.cipher_suite.lock().await; 140 match &*cipher_suite { 141 Some(cipher_suite) => cipher_suite.id() as u16, 142 None => return Err(Error::ErrCipherSuiteUnset), 143 } 144 }; 145 146 Ok(SerializedState { 147 local_epoch, 148 remote_epoch, 149 local_random, 150 remote_random, 151 cipher_suite_id, 152 master_secret: self.master_secret.clone(), 153 sequence_number, 154 srtp_protection_profile: self.srtp_protection_profile as u16, 155 peer_certificates: self.peer_certificates.clone(), 156 identity_hint: self.identity_hint.clone(), 157 is_client: self.is_client, 158 }) 159 } 160 deserialize(&mut self, serialized: &SerializedState) -> Result<()>161 async fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> { 162 // Set epoch values 163 self.local_epoch 164 .store(serialized.local_epoch, Ordering::SeqCst); 165 self.remote_epoch 166 .store(serialized.remote_epoch, Ordering::SeqCst); 167 { 168 let mut lsn = self.local_sequence_number.lock().await; 169 while lsn.len() <= serialized.local_epoch as usize { 170 lsn.push(0); 171 } 172 lsn[serialized.local_epoch as usize] = serialized.sequence_number; 173 } 174 175 // Set random values 176 let mut reader = Cursor::new(&serialized.local_random); 177 self.local_random = HandshakeRandom::unmarshal(&mut reader)?; 178 179 let mut reader = Cursor::new(&serialized.remote_random); 180 self.remote_random = HandshakeRandom::unmarshal(&mut reader)?; 181 182 self.is_client = serialized.is_client; 183 184 // Set master secret 185 self.master_secret = serialized.master_secret.clone(); 186 187 // Set cipher suite 188 self.cipher_suite = Arc::new(Mutex::new(Some(cipher_suite_for_id( 189 serialized.cipher_suite_id.into(), 190 )?))); 191 192 self.srtp_protection_profile = serialized.srtp_protection_profile.into(); 193 194 // Set remote certificate 195 self.peer_certificates = serialized.peer_certificates.clone(); 196 self.identity_hint = serialized.identity_hint.clone(); 197 198 Ok(()) 199 } 200 init_cipher_suite(&mut self) -> Result<()>201 pub async fn init_cipher_suite(&mut self) -> Result<()> { 202 let mut cipher_suite = self.cipher_suite.lock().await; 203 if let Some(cipher_suite) = &mut *cipher_suite { 204 if cipher_suite.is_initialized() { 205 return Ok(()); 206 } 207 208 let mut local_random = vec![]; 209 { 210 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut()); 211 self.local_random.marshal(&mut writer)?; 212 } 213 let mut remote_random = vec![]; 214 { 215 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut()); 216 self.remote_random.marshal(&mut writer)?; 217 } 218 219 if self.is_client { 220 cipher_suite.init(&self.master_secret, &local_random, &remote_random, true) 221 } else { 222 cipher_suite.init(&self.master_secret, &remote_random, &local_random, false) 223 } 224 } else { 225 Err(Error::ErrCipherSuiteUnset) 226 } 227 } 228 229 // marshal_binary is a binary.BinaryMarshaler.marshal_binary implementation marshal_binary(&self) -> Result<Vec<u8>>230 pub async fn marshal_binary(&self) -> Result<Vec<u8>> { 231 let serialized = self.serialize().await?; 232 233 match bincode::serialize(&serialized) { 234 Ok(enc) => Ok(enc), 235 Err(err) => Err(Error::Other(err.to_string())), 236 } 237 } 238 239 // unmarshal_binary is a binary.BinaryUnmarshaler.unmarshal_binary implementation unmarshal_binary(&mut self, data: &[u8]) -> Result<()>240 pub async fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> { 241 let serialized: SerializedState = match bincode::deserialize(data) { 242 Ok(dec) => dec, 243 Err(err) => return Err(Error::Other(err.to_string())), 244 }; 245 self.deserialize(&serialized).await?; 246 self.init_cipher_suite().await?; 247 248 Ok(()) 249 } 250 } 251 252 #[async_trait] 253 impl KeyingMaterialExporter for State { 254 /// export_keying_material returns length bytes of exported key material in a new 255 /// slice as defined in RFC 5705. 256 /// This allows protocols to use DTLS for key establishment, but 257 /// then use some of the keying material for their own purposes export_keying_material( &self, label: &str, context: &[u8], length: usize, ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError>258 async fn export_keying_material( 259 &self, 260 label: &str, 261 context: &[u8], 262 length: usize, 263 ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError> { 264 use KeyingMaterialExporterError::*; 265 266 if self.local_epoch.load(Ordering::SeqCst) == 0 { 267 return Err(HandshakeInProgress); 268 } else if !context.is_empty() { 269 return Err(ContextUnsupported); 270 } else if INVALID_KEYING_LABELS.contains(&label) { 271 return Err(ReservedExportKeyingMaterial); 272 } 273 274 let mut local_random = vec![]; 275 { 276 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut()); 277 self.local_random.marshal(&mut writer)?; 278 } 279 let mut remote_random = vec![]; 280 { 281 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut()); 282 self.remote_random.marshal(&mut writer)?; 283 } 284 285 let mut seed = label.as_bytes().to_vec(); 286 if self.is_client { 287 seed.extend_from_slice(&local_random); 288 seed.extend_from_slice(&remote_random); 289 } else { 290 seed.extend_from_slice(&remote_random); 291 seed.extend_from_slice(&local_random); 292 } 293 294 let cipher_suite = self.cipher_suite.lock().await; 295 if let Some(cipher_suite) = &*cipher_suite { 296 match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) { 297 Ok(v) => Ok(v), 298 Err(err) => Err(Hash(err.to_string())), 299 } 300 } else { 301 Err(CipherSuiteUnset) 302 } 303 } 304 } 305