1 use crate::cipher_suite::*;
2 use crate::config::*;
3 use crate::conn::*;
4 use crate::content::*;
5 use crate::crypto::*;
6 use crate::error::*;
7 use crate::extension::extension_use_srtp::*;
8 use crate::signature_hash_algorithm::*;
9
10 use log::*;
11 use std::collections::HashMap;
12 use std::fmt;
13 use std::sync::Arc;
14
15 //use std::io::BufWriter;
16
// [RFC6347 Section-4.2.4]
//                      +-----------+
//                +---> | PREPARING | <--------------------+
//                |     +-----------+                      |
//                |           |                            |
//                |           | Buffer next flight         |
//                |           |                            |
//                |          \|/                           |
//                |     +-----------+                      |
//                |     |  SENDING  |<------------------+  | Send
//                |     +-----------+                   |  | HelloRequest
//        Receive |           |                         |  |
//           next |           | Send flight             |  | or
//         flight |  +--------+                         |  |
//                |  |        | Set retransmit timer    |  | Receive
//                |  |       \|/                        |  | HelloRequest
//                |  |  +-----------+                   |  | Send
//                +--)--|  WAITING  |-------------------+  | ClientHello
//                |  |  +-----------+   Timer expires   |  |
//                |  |         |                        |  |
//                |  |         +-------------------------+  |
//        Receive |  | Send          Read retransmit        |
//           last |  | last                                |
//         flight |  | flight                              |
//                |  |                                     |
//               \|/\|/                                    |
//            +-----------+                                |
//            | FINISHED  | -------------------------------+
//            +-----------+
//                 |  /|\
//                 |   |
//                 +---+
//             Read retransmit
//         Retransmit last flight
51
/// States of the DTLS handshake state machine (RFC 6347 Section 4.2.4).
///
/// `Debug` and `Eq` are derived so the state can be compared exhaustively
/// and printed in `{:?}` traces; `Display` prints the bare variant name.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HandshakeState {
    /// The handshake failed; the state machine has no transition out of it.
    Errored,
    /// Building the next flight to send.
    Preparing,
    /// Transmitting the buffered flight.
    Sending,
    /// Waiting for the peer's next flight or for the retransmit timer.
    Waiting,
    /// The handshake has finished; late peer messages may still trigger a
    /// retransmission of the last flight.
    Finished,
}

impl fmt::Display for HandshakeState {
    /// Writes the variant name (e.g. `Preparing`).
    ///
    /// Uses `Formatter::pad`, so width/alignment flags such as `{:>10}` are
    /// honored (plain `write!` would silently ignore them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            HandshakeState::Errored => "Errored",
            HandshakeState::Preparing => "Preparing",
            HandshakeState::Sending => "Sending",
            HandshakeState::Waiting => "Waiting",
            HandshakeState::Finished => "Finished",
        };
        f.pad(name)
    }
}
72
73 pub(crate) type VerifyPeerCertificateFn =
74 Arc<dyn (Fn(&[Vec<u8>], &[rustls::Certificate]) -> Result<()>) + Send + Sync>;
75
/// Parameters consulted by the handshake state machine and by the flight
/// generators/parsers it drives.
pub(crate) struct HandshakeConfig {
    /// Optional pre-shared-key callback.
    pub(crate) local_psk_callback: Option<PskCallback>,
    /// Optional PSK identity hint.
    pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
    /// Available cipher suites.
    pub(crate) local_cipher_suites: Vec<CipherSuiteId>,
    /// Available signature schemes.
    pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>,
    /// Policy for the Extended Master Secret extension.
    pub(crate) extended_master_secret: ExtendedMasterSecretType,
    /// Available SRTP protection profiles; if empty, there is no SRTP support.
    pub(crate) local_srtp_protection_profiles: Vec<SrtpProtectionProfile>,
    /// Server name (empty by default).
    /// NOTE(review): presumably the SNI value sent or expected — confirm in the flight code.
    pub(crate) server_name: String,
    /// Whether a client certificate should be requested (see `ClientAuthType`).
    pub(crate) client_auth: ClientAuthType,
    /// Local certificates. `get_certificate` selects from these, and the first
    /// entry is its fallback.
    pub(crate) local_certificates: Vec<Certificate>,
    /// Name (or `*`-wildcard pattern) to certificate lookup used by
    /// `get_certificate`. Lookups use lower-cased names, so keys must be
    /// lower-case to match.
    pub(crate) name_to_certificate: HashMap<String, Certificate>,
    /// NOTE(review): presumably disables peer certificate verification — confirm.
    pub(crate) insecure_skip_verify: bool,
    /// NOTE(review): semantics not visible here — confirm in the flight parsers.
    pub(crate) insecure_verification: bool,
    /// Optional extra peer-certificate verification hook.
    pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
    /// Root CA store (empty by default).
    pub(crate) roots_cas: rustls::RootCertStore,
    /// Verifier applied to the server's certificate.
    pub(crate) server_cert_verifier: Arc<dyn rustls::ServerCertVerifier>,
    /// Optional verifier applied to the client's certificate.
    pub(crate) client_cert_verifier: Option<Arc<dyn rustls::ClientCertVerifier>>,
    /// Duration of the retransmit timer armed in the `Waiting` and `Finished`
    /// handshake states.
    pub(crate) retransmit_interval: tokio::time::Duration,
    /// Offset added to the epoch of every outgoing record when a flight is
    /// prepared.
    pub(crate) initial_epoch: u16,
    // Fields from the Go implementation that were not ported:
    //log logging.LeveledLogger
    //mu sync.Mutex
}
98
99 impl Default for HandshakeConfig {
default() -> Self100 fn default() -> Self {
101 HandshakeConfig {
102 local_psk_callback: None,
103 local_psk_identity_hint: None,
104 local_cipher_suites: vec![],
105 local_signature_schemes: vec![],
106 extended_master_secret: ExtendedMasterSecretType::Disable,
107 local_srtp_protection_profiles: vec![],
108 server_name: String::new(),
109 client_auth: ClientAuthType::NoClientCert,
110 local_certificates: vec![],
111 name_to_certificate: HashMap::new(),
112 insecure_skip_verify: false,
113 insecure_verification: false,
114 verify_peer_certificate: None,
115 roots_cas: rustls::RootCertStore::empty(),
116 server_cert_verifier: Arc::new(rustls::WebPKIVerifier::new()),
117 client_cert_verifier: None,
118 retransmit_interval: tokio::time::Duration::from_secs(0),
119 initial_epoch: 0,
120 }
121 }
122 }
123
124 impl HandshakeConfig {
get_certificate(&self, server_name: &str) -> Result<Certificate>125 pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate> {
126 //TODO
127 /*if self.name_to_certificate.is_empty() {
128 let mut name_to_certificate = HashMap::new();
129 for cert in &self.local_certificates {
130 if let Ok((_rem, x509_cert)) = x509_parser::parse_x509_der(&cert.certificate) {
131 if let Some(a) = x509_cert.tbs_certificate.subject.iter_common_name().next() {
132 let common_name = match a.attr_value.as_str() {
133 Ok(cn) => cn.to_lowercase(),
134 Err(err) => return Err(Error::new(err.to_string())),
135 };
136 name_to_certificate.insert(common_name, cert.clone());
137 }
138 if let Some((_, sans)) = x509_cert.tbs_certificate.subject_alternative_name() {
139 for gn in &sans.general_names {
140 match gn {
141 x509_parser::extensions::GeneralName::DNSName(san) => {
142 let san = san.to_lowercase();
143 name_to_certificate.insert(san, cert.clone());
144 }
145 _ => {}
146 }
147 }
148 }
149 } else {
150 continue;
151 }
152 }
153 self.name_to_certificate = name_to_certificate;
154 }*/
155
156 if self.local_certificates.is_empty() {
157 return Err(Error::ErrNoCertificates);
158 }
159
160 if self.local_certificates.len() == 1 {
161 // There's only one choice, so no point doing any work.
162 return Ok(self.local_certificates[0].clone());
163 }
164
165 if server_name.is_empty() {
166 return Ok(self.local_certificates[0].clone());
167 }
168
169 let lower = server_name.to_lowercase();
170 let name = lower.trim_end_matches('.');
171
172 if let Some(cert) = self.name_to_certificate.get(name) {
173 return Ok(cert.clone());
174 }
175
176 // try replacing labels in the name with wildcards until we get a
177 // match.
178 let mut labels: Vec<&str> = name.split_terminator('.').collect();
179 for i in 0..labels.len() {
180 labels[i] = "*";
181 let candidate = labels.join(".");
182 if let Some(cert) = self.name_to_certificate.get(&candidate) {
183 return Ok(cert.clone());
184 }
185 }
186
187 // If nothing matches, return the first certificate.
188 Ok(self.local_certificates[0].clone())
189 }
190 }
191
/// Returns the role label used in handshake trace messages: `"client"` when
/// `is_client` is true, otherwise `"server"`.
pub(crate) fn srv_cli_str(is_client: bool) -> String {
    let role = if is_client { "client" } else { "server" };
    role.to_owned()
}
198
impl DTLSConn {
    /// Runs the handshake state machine (RFC 6347 Section 4.2.4), starting in
    /// `state`.
    ///
    /// The loop stops the first time it reaches `Finished` while the
    /// handshake is not yet marked as successfully completed. At that point
    /// it marks the handshake as completed, takes and drops
    /// `handshake_done_tx` (presumably to wake whoever awaits completion —
    /// confirm), and returns `Ok(())`. If the handshake was already marked
    /// complete, the `Finished` state is stepped through `finish` instead.
    ///
    /// # Errors
    /// - Propagates any error returned by the per-state step functions.
    /// - Returns `Error::ErrInvalidFsmTransition` when asked to step from
    ///   `HandshakeState::Errored`.
    pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<()> {
        loop {
            trace!(
                "[handshake:{}] {}: {}",
                srv_cli_str(self.state.is_client),
                self.current_flight.to_string(),
                state.to_string()
            );

            if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() {
                self.set_handshake_completed_successfully();
                self.handshake_done_tx.take(); // drop it by take
                return Ok(());
            }

            state = match state {
                HandshakeState::Preparing => self.prepare().await?,
                HandshakeState::Sending => self.send().await?,
                HandshakeState::Waiting => self.wait().await?,
                HandshakeState::Finished => self.finish().await?,
                // Only `Errored` reaches this arm: it has no outgoing transition.
                _ => return Err(Error::ErrInvalidFsmTransition),
            };
        }
    }

    /// `Preparing` step: generates the packets of `current_flight` into
    /// `self.flights`.
    ///
    /// Steps:
    /// 1. Clears the previously buffered flight.
    /// 2. Records whether this flight is retransmitted on timer expiry.
    /// 3. Calls `current_flight.generate`. If that fails, any returned alert
    ///    is sent with `notify` and any returned error is propagated. If the
    ///    error tuple carries no error, execution continues with `flights`
    ///    left as `None`.
    /// 4. Offsets each record's epoch by `cfg.initial_epoch`.
    /// 5. Stamps every handshake message with the next
    ///    `state.handshake_send_sequence` value.
    /// 6. If any record ended up above `initial_epoch`, advances the local
    ///    epoch to the highest such epoch (traced as a changeCipherSpec).
    ///
    /// Returns `Sending` on success.
    async fn prepare(&mut self) -> Result<HandshakeState> {
        self.flights = None;

        // Prepare flights
        self.retransmit = self.current_flight.has_retransmit();

        let result = self
            .current_flight
            .generate(&mut self.state, &self.cache, &self.cfg)
            .await;

        match result {
            Err((a, mut err)) => {
                if let Some(a) = a {
                    let alert_err = self.notify(a.alert_level, a.alert_description).await;

                    // NOTE(review): an alert-send failure replaces `err` only
                    // when `err` is already `Some`; when `err` is `None` the
                    // alert failure is discarded. Confirm this precedence is
                    // intended.
                    if let Err(alert_err) = alert_err {
                        if err.is_some() {
                            err = Some(alert_err);
                        }
                    }
                }
                if let Some(err) = err {
                    return Err(err);
                }
            }
            Ok(pkts) => {
                /*if !pkts.is_empty() {
                    let mut s = vec![];
                    {
                        let mut writer = BufWriter::<&mut Vec<u8>>::new(s.as_mut());
                        pkts[0].record.content.marshal(&mut writer)?;
                    }
                    trace!(
                        "[handshake:{}] {}: {:?}",
                        srv_cli_str(self.state.is_client),
                        self.current_flight.to_string(),
                        s,
                    );
                }*/
                self.flights = Some(pkts)
            }
        };

        let epoch = self.cfg.initial_epoch;
        let mut next_epoch = epoch;
        if let Some(pkts) = &mut self.flights {
            for p in pkts {
                p.record.record_layer_header.epoch += epoch;
                if p.record.record_layer_header.epoch > next_epoch {
                    next_epoch = p.record.record_layer_header.epoch;
                }
                if let Content::Handshake(h) = &mut p.record.content {
                    // NOTE(review): `as u16` truncates if the counter ever
                    // exceeds `u16::MAX`.
                    h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16;
                    self.state.handshake_send_sequence += 1;
                }
            }
        }
        if epoch != next_epoch {
            trace!(
                "[handshake:{}] -> changeCipherSpec (epoch: {})",
                srv_cli_str(self.state.is_client),
                next_epoch
            );
            self.set_local_epoch(next_epoch);
        }

        Ok(HandshakeState::Sending)
    }

    /// `Sending` step: writes the buffered flight (if any) with `write_packets`.
    ///
    /// The flight is cloned rather than taken, so it stays buffered and a
    /// later retransmission (`Waiting`/`Finished` -> `Sending`) can resend it
    /// without re-running `prepare`. Returns `Finished` if this is our last
    /// send flight, otherwise `Waiting`.
    async fn send(&mut self) -> Result<HandshakeState> {
        // Send flights
        if let Some(pkts) = self.flights.clone() {
            self.write_packets(pkts).await?;
        }

        if self.current_flight.is_last_send_flight() {
            Ok(HandshakeState::Finished)
        } else {
            Ok(HandshakeState::Waiting)
        }
    }

    /// `Waiting` step: waits for the peer's next flight or for the retransmit
    /// timer.
    ///
    /// The timer is armed once per call, for `cfg.retransmit_interval`. It is
    /// pinned outside the loop, so incoming messages that do not complete a
    /// flight do not restart it.
    ///
    /// - Handshake channel closed: `Error::ErrAlertFatalOrClose`.
    /// - Parse error: any alert is sent and any error is returned. If no error
    ///   is returned (e.g. the flight is not complete yet), keep waiting.
    /// - Parse success, where the next flight is a last receive flight equal
    ///   to the current one (compared by `Display` string): `Finished`.
    /// - Parse success otherwise: switch to the next flight and return
    ///   `Preparing`.
    /// - Timer expiry: `Sending` (retransmit) if this flight retransmits,
    ///   otherwise `Waiting` again.
    async fn wait(&mut self) -> Result<HandshakeState> {
        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
        tokio::pin!(retransmit_timer);

        loop {
            tokio::select! {
                 done = self.handshake_rx.recv() =>{
                    if done.is_none() {
                        trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string());
                        return Err(Error::ErrAlertFatalOrClose);
                    }

                    //trace!("[handshake:{}] {} received handshake_rx", srv_cli_str(self.state.is_client), self.current_flight.to_string());
                    let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
                    // Dropping `done` after parsing presumably signals the
                    // sender that this message has been handled — confirm.
                    drop(done);
                    match result {
                        Err((alert, mut err)) => {
                            trace!("[handshake:{}] {} result alert:{:?}, err:{:?}",
                                    srv_cli_str(self.state.is_client),
                                    self.current_flight.to_string(),
                                    alert,
                                    err);

                            if let Some(alert) = alert {
                                let alert_err = self.notify(alert.alert_level, alert.alert_description).await;

                                // NOTE(review): same alert-error precedence as in `prepare`.
                                if let Err(alert_err) = alert_err {
                                    if err.is_some() {
                                        err = Some(alert_err);
                                    }
                                }
                            }
                            if let Some(err) = err {
                                return Err(err);
                            }
                        }
                        Ok(next_flight) => {
                            trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight.to_string(), next_flight.to_string());
                            if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() {
                                return Ok(HandshakeState::Finished);
                            }
                            self.current_flight = next_flight;
                            return Ok(HandshakeState::Preparing);
                        }
                    };
                }

                _ = retransmit_timer.as_mut() =>{
                    trace!("[handshake:{}] {} retransmit_timer", srv_cli_str(self.state.is_client), self.current_flight.to_string());

                    if !self.retransmit {
                        return Ok(HandshakeState::Waiting);
                    }
                    return Ok(HandshakeState::Sending);
                }

                /*_ = self.done_rx.recv() => {
                    return Err(Error::new("done_rx recv".to_owned()));
                }*/
            }
        }
    }

    /// `Finished` step: handles at most one late handshake message from the
    /// peer.
    ///
    /// The retransmit timer starts when this function is entered.
    ///
    /// - Handshake channel closed: `Error::ErrAlertFatalOrClose`.
    /// - Parse error: any alert is sent and any error is returned.
    /// - Parse success: wait for the remainder of the timer, then return
    ///   `Sending` to retransmit the last flight.
    /// - Otherwise: stay in `Finished`.
    async fn finish(&mut self) -> Result<HandshakeState> {
        let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);

        tokio::select! {
            done = self.handshake_rx.recv() =>{
                if done.is_none() {
                    trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string());
                    return Err(Error::ErrAlertFatalOrClose);
                }
                let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
                drop(done);
                match result {
                    Err((alert, mut err)) => {
                        if let Some(alert) = alert {
                            let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
                            // NOTE(review): same alert-error precedence as in `prepare`.
                            if let Err(alert_err) = alert_err {
                                if err.is_some() {
                                    err = Some(alert_err);
                                }
                            }
                        }
                        if let Some(err) = err {
                            return Err(err);
                        }
                    }
                    Ok(_) => {
                        retransmit_timer.await;
                        // Retransmit last flight
                        return Ok(HandshakeState::Sending);
                    }
                };
            }

            /*_ = self.done_rx.recv() => {
                return Err(Error::new("done_rx recv".to_owned()));
            }*/
        }

        Ok(HandshakeState::Finished)
    }
}
409