xref: /webrtc/dtls/src/handshaker.rs (revision 03a147ee)
1 use crate::cipher_suite::*;
2 use crate::config::*;
3 use crate::conn::*;
4 use crate::content::*;
5 use crate::crypto::*;
6 use crate::error::*;
7 use crate::extension::extension_use_srtp::*;
8 use crate::signature_hash_algorithm::*;
9 
10 use log::*;
11 use std::collections::HashMap;
12 use std::fmt;
13 use std::sync::Arc;
14 
15 //use std::io::BufWriter;
16 
17 // [RFC6347 Section-4.2.4]
18 //                      +-----------+
19 //                +---> | PREPARING | <--------------------+
20 //                |     +-----------+                      |
21 //                |           |                            |
22 //                |           | Buffer next flight         |
23 //                |           |                            |
24 //                |          \|/                           |
25 //                |     +-----------+                      |
26 //                |     |  SENDING  |<------------------+  | Send
27 //                |     +-----------+                   |  | HelloRequest
28 //        Receive |           |                         |  |
29 //           next |           | Send flight             |  | or
30 //         flight |  +--------+                         |  |
31 //                |  |        | Set retransmit timer    |  | Receive
32 //                |  |       \|/                        |  | HelloRequest
33 //                |  |  +-----------+                   |  | Send
34 //                +--)--|  WAITING  |-------------------+  | ClientHello
35 //                |  |  +-----------+   Timer expires   |  |
36 //                |  |         |                        |  |
37 //                |  |         +------------------------+  |
38 //        Receive |  | Send           Read retransmit      |
39 //           last |  | last                                |
40 //         flight |  | flight                              |
41 //                |  |                                     |
42 //               \|/\|/                                    |
43 //            +-----------+                                |
44 //            | FINISHED  | -------------------------------+
45 //            +-----------+
46 //                 |  /|\
47 //                 |   |
48 //                 +---+
49 //              Read retransmit
50 //           Retransmit last flight
51 
/// States of the DTLS handshake state machine (RFC 6347 Section 4.2.4).
///
/// `Preparing`, `Sending`, `Waiting` and `Finished` correspond to the boxes of
/// the RFC diagram; `Errored` has no counterpart there.
///
/// `Debug` and `Eq` are derived so states can be compared in assertions and
/// printed with `{:?}`; the enum is fieldless, so `Eq` is trivially sound.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HandshakeState {
    Errored,
    Preparing,
    Sending,
    Waiting,
    Finished,
}
60 
61 impl fmt::Display for HandshakeState {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result62     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63         match *self {
64             HandshakeState::Errored => write!(f, "Errored"),
65             HandshakeState::Preparing => write!(f, "Preparing"),
66             HandshakeState::Sending => write!(f, "Sending"),
67             HandshakeState::Waiting => write!(f, "Waiting"),
68             HandshakeState::Finished => write!(f, "Finished"),
69         }
70     }
71 }
72 
73 pub(crate) type VerifyPeerCertificateFn =
74     Arc<dyn (Fn(&[Vec<u8>], &[rustls::Certificate]) -> Result<()>) + Send + Sync>;
75 
76 pub(crate) struct HandshakeConfig {
77     pub(crate) local_psk_callback: Option<PskCallback>,
78     pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
79     pub(crate) local_cipher_suites: Vec<CipherSuiteId>, // Available CipherSuites
80     pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>, // Available signature schemes
81     pub(crate) extended_master_secret: ExtendedMasterSecretType, // Policy for the Extended Master Support extension
82     pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>, // Available SRTPProtectionProfiles, if empty no SRTP support
83     pub(crate) server_name: String,
84     pub(crate) client_auth: ClientAuthType, // If we are a client should we request a client certificate
85     pub(crate) local_certificates: Vec<Certificate>,
86     pub(crate) name_to_certificate: HashMap<String, Certificate>,
87     pub(crate) insecure_skip_verify: bool,
88     pub(crate) insecure_verification: bool,
89     pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
90     pub(crate) roots_cas: rustls::RootCertStore,
91     pub(crate) server_cert_verifier: Arc<dyn rustls::ServerCertVerifier>,
92     pub(crate) client_cert_verifier: Option<Arc<dyn rustls::ClientCertVerifier>>,
93     pub(crate) retransmit_interval: tokio::time::Duration,
94     pub(crate) initial_epoch: u16,
95     //log           logging.LeveledLogger
96     //mu sync.Mutex
97 }
98 
99 impl Default for HandshakeConfig {
default() -> Self100     fn default() -> Self {
101         HandshakeConfig {
102             local_psk_callback: None,
103             local_psk_identity_hint: None,
104             local_cipher_suites: vec![],
105             local_signature_schemes: vec![],
106             extended_master_secret: ExtendedMasterSecretType::Disable,
107             local_srtp_protection_profiles: vec![],
108             server_name: String::new(),
109             client_auth: ClientAuthType::NoClientCert,
110             local_certificates: vec![],
111             name_to_certificate: HashMap::new(),
112             insecure_skip_verify: false,
113             insecure_verification: false,
114             verify_peer_certificate: None,
115             roots_cas: rustls::RootCertStore::empty(),
116             server_cert_verifier: Arc::new(rustls::WebPKIVerifier::new()),
117             client_cert_verifier: None,
118             retransmit_interval: tokio::time::Duration::from_secs(0),
119             initial_epoch: 0,
120         }
121     }
122 }
123 
124 impl HandshakeConfig {
get_certificate(&self, server_name: &str) -> Result<Certificate>125     pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
126         //TODO
127         /*if self.name_to_certificate.is_empty() {
128             let mut name_to_certificate = HashMap::new();
129             for cert in &self.local_certificates {
130                 if let Ok((_rem, x509_cert)) = x509_parser::parse_x509_der(&cert.certificate) {
131                     if let Some(a) = x509_cert.tbs_certificate.subject.iter_common_name().next() {
132                         let common_name = match a.attr_value.as_str() {
133                             Ok(cn) => cn.to_lowercase(),
134                             Err(err) => return Err(Error::new(err.to_string())),
135                         };
136                         name_to_certificate.insert(common_name, cert.clone());
137                     }
138                     if let Some((_, sans)) = x509_cert.tbs_certificate.subject_alternative_name() {
139                         for gn in &sans.general_names {
140                             match gn {
141                                 x509_parser::extensions::GeneralName::DNSName(san) => {
142                                     let san = san.to_lowercase();
143                                     name_to_certificate.insert(san, cert.clone());
144                                 }
145                                 _ => {}
146                             }
147                         }
148                     }
149                 } else {
150                     continue;
151                 }
152             }
153             self.name_to_certificate = name_to_certificate;
154         }*/
155 
156         if self.local_certificates.is_empty() {
157             return Err(Error::ErrNoCertificates);
158         }
159 
160         if self.local_certificates.len() == 1 {
161             // There's only one choice, so no point doing any work.
162             return Ok(self.local_certificates[0].clone());
163         }
164 
165         if server_name.is_empty() {
166             return Ok(self.local_certificates[0].clone());
167         }
168 
169         let lower = server_name.to_lowercase();
170         let name = lower.trim_end_matches('.');
171 
172         if let Some(cert) = self.name_to_certificate.get(name) {
173             return Ok(cert.clone());
174         }
175 
176         // try replacing labels in the name with wildcards until we get a
177         // match.
178         let mut labels: Vec<&str> = name.split_terminator('.').collect();
179         for i in 0..labels.len() {
180             labels[i] = "*";
181             let candidate = labels.join(".");
182             if let Some(cert) = self.name_to_certificate.get(&candidate) {
183                 return Ok(cert.clone());
184             }
185         }
186 
187         // If nothing matches, return the first certificate.
188         Ok(self.local_certificates[0].clone())
189     }
190 }
191 
/// Side label used in handshake trace logs: `"client"` or `"server"`.
pub(crate) fn srv_cli_str(is_client: bool) -> String {
    let side = if is_client { "client" } else { "server" };
    side.to_owned()
}
198 
impl DTLSConn {
    /// Runs the handshake state machine (RFC 6347 Section 4.2.4), starting at `state`.
    ///
    /// Each iteration traces the current flight and state, then calls the
    /// handler for that state, which returns the next state:
    /// `Preparing` -> `prepare`, `Sending` -> `send`, `Waiting` -> `wait`,
    /// `Finished` -> `finish`.
    ///
    /// Returns `Ok(())` the first time `Finished` is reached while the
    /// handshake is not yet marked as completed. At that point it is marked
    /// completed and `handshake_done_tx` is dropped. After completion,
    /// `Finished` no longer exits the loop; it dispatches to `finish`
    /// (retransmission handling), so the loop then only ends with an error.
    ///
    /// # Errors
    ///
    /// Propagates any handler error. Returns `Error::ErrInvalidFsmTransition`
    /// when asked to run the `Errored` state.
    pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<()> {
        loop {
            trace!(
                "[handshake:{}] {}: {}",
                srv_cli_str(self.state.is_client),
                self.current_flight.to_string(),
                state.to_string()
            );

            // First arrival at FINISHED completes the handshake: record it and
            // release the completion sender.
            if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() {
                self.set_handshake_completed_successfully();
                self.handshake_done_tx.take(); // drop it by take
                return Ok(());
            }

            state = match state {
                HandshakeState::Preparing => self.prepare().await?,
                HandshakeState::Sending => self.send().await?,
                HandshakeState::Waiting => self.wait().await?,
                HandshakeState::Finished => self.finish().await?,
                // Only `Errored` reaches this arm; it has no handler.
                _ => return Err(Error::ErrInvalidFsmTransition),
            };
        }
    }

    /// PREPARING state: builds the next outgoing flight.
    ///
    /// Clears any buffered flight and records whether the current flight is
    /// retransmittable. It then asks the current flight to generate its
    /// packets; if generation reports an alert, the alert is sent to the peer
    /// first.
    ///
    /// Each generated packet is then stamped. Its record epoch is offset by
    /// `cfg.initial_epoch`, and each handshake message gets the next
    /// `handshake_send_sequence` (truncated to `u16`). If any packet ends up
    /// with a higher epoch than `initial_epoch`, the local epoch is advanced
    /// to it (traced as a changeCipherSpec).
    ///
    /// Returns `Sending` unless generation produced an error.
    async fn prepare(&mut self) -> Result<HandshakeState> {
        self.flights = None;

        // Prepare flights
        self.retransmit = self.current_flight.has_retransmit();

        let result = self
            .current_flight
            .generate(&mut self.state, &self.cache, &self.cfg)
            .await;

        match result {
            Err((a, mut err)) => {
                if let Some(a) = a {
                    let alert_err = self.notify(a.alert_level, a.alert_description).await;

                    // NOTE(review): a failed notify only replaces `err` when
                    // `err` is already `Some`; when the flight reported only
                    // an alert, the notify failure is silently ignored.
                    // Confirm this is intended.
                    if let Err(alert_err) = alert_err {
                        if err.is_some() {
                            err = Some(alert_err);
                        }
                    }
                }
                if let Some(err) = err {
                    return Err(err);
                }
                // No error: continue with no buffered flight (`self.flights`
                // stays `None`, so `send` writes nothing).
            }
            Ok(pkts) => {
                self.flights = Some(pkts)
            }
        };

        let epoch = self.cfg.initial_epoch;
        let mut next_epoch = epoch;
        if let Some(pkts) = &mut self.flights {
            for p in pkts {
                p.record.record_layer_header.epoch += epoch;
                if p.record.record_layer_header.epoch > next_epoch {
                    next_epoch = p.record.record_layer_header.epoch;
                }
                if let Content::Handshake(h) = &mut p.record.content {
                    // NOTE(review): `as u16` silently wraps the sequence number.
                    h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16;
                    self.state.handshake_send_sequence += 1;
                }
            }
        }
        if epoch != next_epoch {
            trace!(
                "[handshake:{}] -> changeCipherSpec (epoch: {})",
                srv_cli_str(self.state.is_client),
                next_epoch
            );
            self.set_local_epoch(next_epoch);
        }

        Ok(HandshakeState::Sending)
    }

    /// SENDING state: writes the buffered flight, if any.
    ///
    /// The flight is cloned rather than taken, so it stays in `self.flights`.
    /// When the retransmit timer fires, `wait` returns `Sending` and the same
    /// flight is written again.
    ///
    /// Returns `Finished` after our last flight, otherwise `Waiting`.
    async fn send(&mut self) -> Result<HandshakeState> {
        // Send flights
        if let Some(pkts) = self.flights.clone() {
            self.write_packets(pkts).await?;
        }

        if self.current_flight.is_last_send_flight() {
            Ok(HandshakeState::Finished)
        } else {
            Ok(HandshakeState::Waiting)
        }
    }

    /// WAITING state: waits for the peer's next flight or for the retransmit timer.
    ///
    /// On a `handshake_rx` notification, the current flight parses the queued
    /// handshake messages:
    /// - an `Err` carrying an error aborts (after sending any alert);
    /// - an `Err` without an error (at most an alert) keeps waiting on the same timer;
    /// - `Ok(next)` returns `Finished` if `next` is the last receive flight and
    ///   equals the current flight (compared by string form); otherwise it makes
    ///   `next` the current flight and returns `Preparing`.
    ///
    /// When the retransmit timer expires, it returns `Sending` if the flight is
    /// retransmittable, otherwise `Waiting` (which re-enters with a new timer).
    ///
    /// # Errors
    ///
    /// Returns `Error::ErrAlertFatalOrClose` if the sender side of
    /// `handshake_rx` has been dropped, and propagates parse and notify errors.
    async fn wait(&mut self) -> Result<HandshakeState> {
        // The deadline is fixed once per call; pinning lets the same timer be
        // polled across loop iterations.
        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
        tokio::pin!(retransmit_timer);

        loop {
            tokio::select! {
                done = self.handshake_rx.recv() => {
                    if done.is_none() {
                        trace!(
                            "[handshake:{}] {} handshake_tx is dropped",
                            srv_cli_str(self.state.is_client),
                            self.current_flight.to_string()
                        );
                        return Err(Error::ErrAlertFatalOrClose);
                    }

                    let result = self
                        .current_flight
                        .parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg)
                        .await;
                    // `done` is held until parsing has finished, then released
                    // (presumably the sending side waits on it — confirm).
                    drop(done);
                    match result {
                        Err((alert, mut err)) => {
                            trace!(
                                "[handshake:{}] {} result alert:{:?}, err:{:?}",
                                srv_cli_str(self.state.is_client),
                                self.current_flight.to_string(),
                                alert,
                                err
                            );

                            if let Some(alert) = alert {
                                let alert_err = self.notify(alert.alert_level, alert.alert_description).await;

                                // NOTE(review): same merge rule as in `prepare`;
                                // a notify failure is dropped when `err` is `None`.
                                if let Err(alert_err) = alert_err {
                                    if err.is_some() {
                                        err = Some(alert_err);
                                    }
                                }
                            }
                            if let Some(err) = err {
                                return Err(err);
                            }
                            // No error: fall through and keep waiting.
                        }
                        Ok(next_flight) => {
                            trace!(
                                "[handshake:{}] {} -> {}",
                                srv_cli_str(self.state.is_client),
                                self.current_flight.to_string(),
                                next_flight.to_string()
                            );
                            if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() {
                                return Ok(HandshakeState::Finished);
                            }
                            self.current_flight = next_flight;
                            return Ok(HandshakeState::Preparing);
                        }
                    };
                }

                _ = retransmit_timer.as_mut() => {
                    trace!(
                        "[handshake:{}] {} retransmit_timer",
                        srv_cli_str(self.state.is_client),
                        self.current_flight.to_string()
                    );

                    if !self.retransmit {
                        return Ok(HandshakeState::Waiting);
                    }
                    return Ok(HandshakeState::Sending);
                }

                // NOTE(review): there is no cancellation/close branch here;
                // waiting ends only via `handshake_rx` or the timer.
            }
        }
    }

    /// FINISHED state: handles traffic that arrives after our last flight was sent.
    ///
    /// There is no timer branch, so this blocks until `handshake_rx` yields,
    /// then parses with the current flight:
    /// - an `Err` carrying an error aborts (after sending any alert);
    /// - an `Err` without an error returns `Finished`;
    /// - `Ok(_)` waits until the retransmit deadline (measured from entry to
    ///   this method) and returns `Sending`, so the last flight is resent.
    ///
    /// # Errors
    ///
    /// Returns `Error::ErrAlertFatalOrClose` if the sender side of
    /// `handshake_rx` has been dropped, and propagates parse and notify errors.
    async fn finish(&mut self) -> Result<HandshakeState> {
        // Deadline fixed at entry; only awaited after a successful parse.
        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);

        tokio::select! {
            done = self.handshake_rx.recv() => {
                if done.is_none() {
                    trace!(
                        "[handshake:{}] {} handshake_tx is dropped",
                        srv_cli_str(self.state.is_client),
                        self.current_flight.to_string()
                    );
                    return Err(Error::ErrAlertFatalOrClose);
                }
                let result = self
                    .current_flight
                    .parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg)
                    .await;
                // Release `done` only after parsing has finished.
                drop(done);
                match result {
                    Err((alert, mut err)) => {
                        if let Some(alert) = alert {
                            let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
                            // NOTE(review): same merge rule as in `prepare`.
                            if let Err(alert_err) = alert_err {
                                if err.is_some() {
                                    err = Some(alert_err);
                                }
                            }
                        }
                        if let Some(err) = err {
                            return Err(err);
                        }
                    }
                    Ok(_) => {
                        retransmit_timer.await;
                        // Retransmit last flight
                        return Ok(HandshakeState::Sending);
                    }
                };
            }

            // NOTE(review): there is no cancellation/close branch here.
        }

        Ok(HandshakeState::Finished)
    }
}
409