xref: /webrtc/dtls/src/state.rs (revision 2cd6154e)
1 use super::cipher_suite::*;
2 use super::conn::*;
3 use super::curve::named_curve::*;
4 use super::extension::extension_use_srtp::SrtpProtectionProfile;
5 use super::handshake::handshake_random::*;
6 use super::prf::*;
7 use crate::error::*;
8 
9 use async_trait::async_trait;
10 use serde::{Deserialize, Serialize};
11 use std::io::{BufWriter, Cursor};
12 use std::marker::{Send, Sync};
13 use std::sync::atomic::{AtomicU16, Ordering};
14 use std::sync::Arc;
15 use tokio::sync::Mutex;
16 use util::KeyingMaterialExporter;
17 use util::KeyingMaterialExporterError;
18 
19 // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
20 pub struct State {
21     pub(crate) local_epoch: Arc<AtomicU16>,
22     pub(crate) remote_epoch: Arc<AtomicU16>,
23     pub(crate) local_sequence_number: Arc<Mutex<Vec<u64>>>, // uint48
24     pub(crate) local_random: HandshakeRandom,
25     pub(crate) remote_random: HandshakeRandom,
26     pub(crate) master_secret: Vec<u8>,
27     pub(crate) cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, // nil if a cipher_suite hasn't been chosen
28 
29     pub(crate) srtp_protection_profile: SrtpProtectionProfile, // Negotiated srtp_protection_profile
30     pub peer_certificates: Vec<Vec<u8>>,
31     pub identity_hint: Vec<u8>,
32 
33     pub(crate) is_client: bool,
34 
35     pub(crate) pre_master_secret: Vec<u8>,
36     pub(crate) extended_master_secret: bool,
37 
38     pub(crate) named_curve: NamedCurve,
39     pub(crate) local_keypair: Option<NamedCurveKeypair>,
40     pub(crate) cookie: Vec<u8>,
41     pub(crate) handshake_send_sequence: isize,
42     pub(crate) handshake_recv_sequence: isize,
43     pub(crate) server_name: String,
44     pub(crate) remote_requested_certificate: bool, // Did we get a CertificateRequest
45     pub(crate) local_certificates_verify: Vec<u8>, // cache CertificateVerify
46     pub(crate) local_verify_data: Vec<u8>,         // cached VerifyData
47     pub(crate) local_key_signature: Vec<u8>,       // cached keySignature
48     pub(crate) peer_certificates_verified: bool,
49     //pub(crate) replay_detector: Vec<Box<dyn ReplayDetector + Send + Sync>>,
50 }
51 
52 #[derive(Serialize, Deserialize, PartialEq, Debug)]
53 struct SerializedState {
54     local_epoch: u16,
55     remote_epoch: u16,
56     local_random: [u8; HANDSHAKE_RANDOM_LENGTH],
57     remote_random: [u8; HANDSHAKE_RANDOM_LENGTH],
58     cipher_suite_id: u16,
59     master_secret: Vec<u8>,
60     sequence_number: u64,
61     srtp_protection_profile: u16,
62     peer_certificates: Vec<Vec<u8>>,
63     identity_hint: Vec<u8>,
64     is_client: bool,
65 }
66 
67 impl Default for State {
default() -> Self68     fn default() -> Self {
69         State {
70             local_epoch: Arc::new(AtomicU16::new(0)),
71             remote_epoch: Arc::new(AtomicU16::new(0)),
72             local_sequence_number: Arc::new(Mutex::new(vec![])),
73             local_random: HandshakeRandom::default(),
74             remote_random: HandshakeRandom::default(),
75             master_secret: vec![],
76             cipher_suite: Arc::new(Mutex::new(None)), // nil if a cipher_suite hasn't been chosen
77 
78             srtp_protection_profile: SrtpProtectionProfile::Unsupported, // Negotiated srtp_protection_profile
79             peer_certificates: vec![],
80             identity_hint: vec![],
81 
82             is_client: false,
83 
84             pre_master_secret: vec![],
85             extended_master_secret: false,
86 
87             named_curve: NamedCurve::Unsupported,
88             local_keypair: None,
89             cookie: vec![],
90             handshake_send_sequence: 0,
91             handshake_recv_sequence: 0,
92             server_name: "".to_string(),
93             remote_requested_certificate: false, // Did we get a CertificateRequest
94             local_certificates_verify: vec![],   // cache CertificateVerify
95             local_verify_data: vec![],           // cached VerifyData
96             local_key_signature: vec![],         // cached keySignature
97             peer_certificates_verified: false,
98             //replay_detector: vec![],
99         }
100     }
101 }
102 
103 impl State {
clone(&self) -> Self104     pub(crate) async fn clone(&self) -> Self {
105         let mut state = State::default();
106 
107         if let Ok(serialized) = self.serialize().await {
108             let _ = state.deserialize(&serialized).await;
109         }
110 
111         state
112     }
113 
serialize(&self) -> Result<SerializedState>114     async fn serialize(&self) -> Result<SerializedState> {
115         let mut local_rand = vec![];
116         {
117             let mut writer = BufWriter::<&mut Vec<u8>>::new(local_rand.as_mut());
118             self.local_random.marshal(&mut writer)?;
119         }
120         let mut remote_rand = vec![];
121         {
122             let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_rand.as_mut());
123             self.remote_random.marshal(&mut writer)?;
124         }
125 
126         let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
127         let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
128 
129         local_random.copy_from_slice(&local_rand);
130         remote_random.copy_from_slice(&remote_rand);
131 
132         let local_epoch = self.local_epoch.load(Ordering::SeqCst);
133         let remote_epoch = self.remote_epoch.load(Ordering::SeqCst);
134         let sequence_number = {
135             let lsn = self.local_sequence_number.lock().await;
136             lsn[local_epoch as usize]
137         };
138         let cipher_suite_id = {
139             let cipher_suite = self.cipher_suite.lock().await;
140             match &*cipher_suite {
141                 Some(cipher_suite) => cipher_suite.id() as u16,
142                 None => return Err(Error::ErrCipherSuiteUnset),
143             }
144         };
145 
146         Ok(SerializedState {
147             local_epoch,
148             remote_epoch,
149             local_random,
150             remote_random,
151             cipher_suite_id,
152             master_secret: self.master_secret.clone(),
153             sequence_number,
154             srtp_protection_profile: self.srtp_protection_profile as u16,
155             peer_certificates: self.peer_certificates.clone(),
156             identity_hint: self.identity_hint.clone(),
157             is_client: self.is_client,
158         })
159     }
160 
deserialize(&mut self, serialized: &SerializedState) -> Result<()>161     async fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> {
162         // Set epoch values
163         self.local_epoch
164             .store(serialized.local_epoch, Ordering::SeqCst);
165         self.remote_epoch
166             .store(serialized.remote_epoch, Ordering::SeqCst);
167         {
168             let mut lsn = self.local_sequence_number.lock().await;
169             while lsn.len() <= serialized.local_epoch as usize {
170                 lsn.push(0);
171             }
172             lsn[serialized.local_epoch as usize] = serialized.sequence_number;
173         }
174 
175         // Set random values
176         let mut reader = Cursor::new(&serialized.local_random);
177         self.local_random = HandshakeRandom::unmarshal(&mut reader)?;
178 
179         let mut reader = Cursor::new(&serialized.remote_random);
180         self.remote_random = HandshakeRandom::unmarshal(&mut reader)?;
181 
182         self.is_client = serialized.is_client;
183 
184         // Set master secret
185         self.master_secret = serialized.master_secret.clone();
186 
187         // Set cipher suite
188         self.cipher_suite = Arc::new(Mutex::new(Some(cipher_suite_for_id(
189             serialized.cipher_suite_id.into(),
190         )?)));
191 
192         self.srtp_protection_profile = serialized.srtp_protection_profile.into();
193 
194         // Set remote certificate
195         self.peer_certificates = serialized.peer_certificates.clone();
196         self.identity_hint = serialized.identity_hint.clone();
197 
198         Ok(())
199     }
200 
init_cipher_suite(&mut self) -> Result<()>201     pub async fn init_cipher_suite(&mut self) -> Result<()> {
202         let mut cipher_suite = self.cipher_suite.lock().await;
203         if let Some(cipher_suite) = &mut *cipher_suite {
204             if cipher_suite.is_initialized() {
205                 return Ok(());
206             }
207 
208             let mut local_random = vec![];
209             {
210                 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
211                 self.local_random.marshal(&mut writer)?;
212             }
213             let mut remote_random = vec![];
214             {
215                 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
216                 self.remote_random.marshal(&mut writer)?;
217             }
218 
219             if self.is_client {
220                 cipher_suite.init(&self.master_secret, &local_random, &remote_random, true)
221             } else {
222                 cipher_suite.init(&self.master_secret, &remote_random, &local_random, false)
223             }
224         } else {
225             Err(Error::ErrCipherSuiteUnset)
226         }
227     }
228 
229     // marshal_binary is a binary.BinaryMarshaler.marshal_binary implementation
marshal_binary(&self) -> Result<Vec<u8>>230     pub async fn marshal_binary(&self) -> Result<Vec<u8>> {
231         let serialized = self.serialize().await?;
232 
233         match bincode::serialize(&serialized) {
234             Ok(enc) => Ok(enc),
235             Err(err) => Err(Error::Other(err.to_string())),
236         }
237     }
238 
239     // unmarshal_binary is a binary.BinaryUnmarshaler.unmarshal_binary implementation
unmarshal_binary(&mut self, data: &[u8]) -> Result<()>240     pub async fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> {
241         let serialized: SerializedState = match bincode::deserialize(data) {
242             Ok(dec) => dec,
243             Err(err) => return Err(Error::Other(err.to_string())),
244         };
245         self.deserialize(&serialized).await?;
246         self.init_cipher_suite().await?;
247 
248         Ok(())
249     }
250 }
251 
252 #[async_trait]
253 impl KeyingMaterialExporter for State {
254     /// export_keying_material returns length bytes of exported key material in a new
255     /// slice as defined in RFC 5705.
256     /// This allows protocols to use DTLS for key establishment, but
257     /// then use some of the keying material for their own purposes
export_keying_material( &self, label: &str, context: &[u8], length: usize, ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError>258     async fn export_keying_material(
259         &self,
260         label: &str,
261         context: &[u8],
262         length: usize,
263     ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError> {
264         use KeyingMaterialExporterError::*;
265 
266         if self.local_epoch.load(Ordering::SeqCst) == 0 {
267             return Err(HandshakeInProgress);
268         } else if !context.is_empty() {
269             return Err(ContextUnsupported);
270         } else if INVALID_KEYING_LABELS.contains(&label) {
271             return Err(ReservedExportKeyingMaterial);
272         }
273 
274         let mut local_random = vec![];
275         {
276             let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
277             self.local_random.marshal(&mut writer)?;
278         }
279         let mut remote_random = vec![];
280         {
281             let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
282             self.remote_random.marshal(&mut writer)?;
283         }
284 
285         let mut seed = label.as_bytes().to_vec();
286         if self.is_client {
287             seed.extend_from_slice(&local_random);
288             seed.extend_from_slice(&remote_random);
289         } else {
290             seed.extend_from_slice(&remote_random);
291             seed.extend_from_slice(&local_random);
292         }
293 
294         let cipher_suite = self.cipher_suite.lock().await;
295         if let Some(cipher_suite) = &*cipher_suite {
296             match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) {
297                 Ok(v) => Ok(v),
298                 Err(err) => Err(Hash(err.to_string())),
299             }
300         } else {
301             Err(CipherSuiteUnset)
302         }
303     }
304 }
305