xref: /webrtc/dtls/src/flight/flight3.rs (revision ffe74184)
1 use super::flight5::*;
2 use super::*;
3 use crate::compression_methods::*;
4 use crate::config::*;
5 use crate::content::*;
6 use crate::curve::named_curve::*;
7 use crate::error::Error;
8 use crate::extension::extension_server_name::*;
9 use crate::extension::extension_supported_elliptic_curves::*;
10 use crate::extension::extension_supported_point_formats::*;
11 use crate::extension::extension_supported_signature_algorithms::*;
12 use crate::extension::extension_use_extended_master_secret::*;
13 use crate::extension::extension_use_srtp::*;
14 use crate::extension::*;
15 use crate::handshake::handshake_message_client_hello::*;
16 use crate::handshake::handshake_message_server_key_exchange::*;
17 use crate::handshake::*;
18 use crate::record_layer::record_layer_header::*;
19 use crate::record_layer::*;
20 
21 use crate::cipher_suite::cipher_suite_for_id;
22 use crate::prf::{prf_pre_master_secret, prf_psk_pre_master_secret};
23 use crate::{find_matching_cipher_suite, find_matching_srtp_profile};
24 
25 use crate::extension::renegotiation_info::ExtensionRenegotiationInfo;
26 use async_trait::async_trait;
27 use log::*;
28 use std::fmt;
29 
/// Client-side flight 3 of the DTLS handshake.
///
/// It generates a ClientHello that echoes the cookie received in a HelloVerifyRequest.
/// It also parses the server's reply to that ClientHello.
#[derive(Debug, PartialEq)]
pub(crate) struct Flight3;

impl fmt::Display for Flight3 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Flight 3")
    }
}
38 
39 #[async_trait]
40 impl Flight for Flight3 {
parse( &self, _tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>41     async fn parse(
42         &self,
43         _tx: &mut mpsc::Sender<mpsc::Sender<()>>,
44         state: &mut State,
45         cache: &HandshakeCache,
46         cfg: &HandshakeConfig,
47     ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
48         // Clients may receive multiple HelloVerifyRequest messages with different cookies.
49         // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
50         // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
51         if let Ok((seq, msgs)) = cache
52             .full_pull_map(
53                 state.handshake_recv_sequence,
54                 &[HandshakeCachePullRule {
55                     typ: HandshakeType::HelloVerifyRequest,
56                     epoch: cfg.initial_epoch,
57                     is_client: false,
58                     optional: true,
59                 }],
60             )
61             .await
62         {
63             if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) {
64                 // DTLS 1.2 clients must not assume that the server will use the protocol version
65                 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
66                 let h = match message {
67                     HandshakeMessage::HelloVerifyRequest(h) => h,
68                     _ => {
69                         return Err((
70                             Some(Alert {
71                                 alert_level: AlertLevel::Fatal,
72                                 alert_description: AlertDescription::InternalError,
73                             }),
74                             None,
75                         ))
76                     }
77                 };
78 
79                 // DTLS 1.2 clients must not assume that the server will use the protocol version
80                 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
81                 if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 {
82                     return Err((
83                         Some(Alert {
84                             alert_level: AlertLevel::Fatal,
85                             alert_description: AlertDescription::ProtocolVersion,
86                         }),
87                         Some(Error::ErrUnsupportedProtocolVersion),
88                     ));
89                 }
90 
91                 state.cookie = h.cookie.clone();
92                 state.handshake_recv_sequence = seq;
93                 return Ok(Box::new(Flight3 {}) as Box<dyn Flight + Send + Sync>);
94             }
95         }
96 
97         let result = if cfg.local_psk_callback.is_some() {
98             cache
99                 .full_pull_map(
100                     state.handshake_recv_sequence,
101                     &[
102                         HandshakeCachePullRule {
103                             typ: HandshakeType::ServerHello,
104                             epoch: cfg.initial_epoch,
105                             is_client: false,
106                             optional: false,
107                         },
108                         HandshakeCachePullRule {
109                             typ: HandshakeType::ServerKeyExchange,
110                             epoch: cfg.initial_epoch,
111                             is_client: false,
112                             optional: true,
113                         },
114                         HandshakeCachePullRule {
115                             typ: HandshakeType::ServerHelloDone,
116                             epoch: cfg.initial_epoch,
117                             is_client: false,
118                             optional: false,
119                         },
120                     ],
121                 )
122                 .await
123         } else {
124             cache
125                 .full_pull_map(
126                     state.handshake_recv_sequence,
127                     &[
128                         HandshakeCachePullRule {
129                             typ: HandshakeType::ServerHello,
130                             epoch: cfg.initial_epoch,
131                             is_client: false,
132                             optional: false,
133                         },
134                         HandshakeCachePullRule {
135                             typ: HandshakeType::Certificate,
136                             epoch: cfg.initial_epoch,
137                             is_client: false,
138                             optional: true,
139                         },
140                         HandshakeCachePullRule {
141                             typ: HandshakeType::ServerKeyExchange,
142                             epoch: cfg.initial_epoch,
143                             is_client: false,
144                             optional: false,
145                         },
146                         HandshakeCachePullRule {
147                             typ: HandshakeType::CertificateRequest,
148                             epoch: cfg.initial_epoch,
149                             is_client: false,
150                             optional: true,
151                         },
152                         HandshakeCachePullRule {
153                             typ: HandshakeType::ServerHelloDone,
154                             epoch: cfg.initial_epoch,
155                             is_client: false,
156                             optional: false,
157                         },
158                     ],
159                 )
160                 .await
161         };
162 
163         let (seq, msgs) = match result {
164             Ok((seq, msgs)) => (seq, msgs),
165             Err(_) => return Err((None, None)),
166         };
167 
168         state.handshake_recv_sequence = seq;
169 
170         if let Some(message) = msgs.get(&HandshakeType::ServerHello) {
171             let h = match message {
172                 HandshakeMessage::ServerHello(h) => h,
173                 _ => {
174                     return Err((
175                         Some(Alert {
176                             alert_level: AlertLevel::Fatal,
177                             alert_description: AlertDescription::InternalError,
178                         }),
179                         None,
180                     ))
181                 }
182             };
183 
184             if h.version != PROTOCOL_VERSION1_2 {
185                 return Err((
186                     Some(Alert {
187                         alert_level: AlertLevel::Fatal,
188                         alert_description: AlertDescription::ProtocolVersion,
189                     }),
190                     Some(Error::ErrUnsupportedProtocolVersion),
191                 ));
192             }
193 
194             for extension in &h.extensions {
195                 match extension {
196                     Extension::UseSrtp(e) => {
197                         let profile = match find_matching_srtp_profile(
198                             &e.protection_profiles,
199                             &cfg.local_srtp_protection_profiles,
200                         ) {
201                             Ok(profile) => profile,
202                             Err(_) => {
203                                 return Err((
204                                     Some(Alert {
205                                         alert_level: AlertLevel::Fatal,
206                                         alert_description: AlertDescription::IllegalParameter,
207                                     }),
208                                     Some(Error::ErrClientNoMatchingSrtpProfile),
209                                 ))
210                             }
211                         };
212                         state.srtp_protection_profile = profile;
213                     }
214                     Extension::UseExtendedMasterSecret(_) => {
215                         if cfg.extended_master_secret != ExtendedMasterSecretType::Disable {
216                             state.extended_master_secret = true;
217                         }
218                     }
219                     _ => {}
220                 };
221             }
222 
223             if cfg.extended_master_secret == ExtendedMasterSecretType::Require
224                 && !state.extended_master_secret
225             {
226                 return Err((
227                     Some(Alert {
228                         alert_level: AlertLevel::Fatal,
229                         alert_description: AlertDescription::InsufficientSecurity,
230                     }),
231                     Some(Error::ErrClientRequiredButNoServerEms),
232                 ));
233             }
234             if !cfg.local_srtp_protection_profiles.is_empty()
235                 && state.srtp_protection_profile == SrtpProtectionProfile::Unsupported
236             {
237                 return Err((
238                     Some(Alert {
239                         alert_level: AlertLevel::Fatal,
240                         alert_description: AlertDescription::InsufficientSecurity,
241                     }),
242                     Some(Error::ErrRequestedButNoSrtpExtension),
243                 ));
244             }
245             if find_matching_cipher_suite(&[h.cipher_suite], &cfg.local_cipher_suites).is_err() {
246                 debug!(
247                     "[handshake:{}] use cipher suite: {}",
248                     srv_cli_str(state.is_client),
249                     h.cipher_suite
250                 );
251 
252                 return Err((
253                     Some(Alert {
254                         alert_level: AlertLevel::Fatal,
255                         alert_description: AlertDescription::InsufficientSecurity,
256                     }),
257                     Some(Error::ErrCipherSuiteNoIntersection),
258                 ));
259             }
260 
261             let cipher_suite = match cipher_suite_for_id(h.cipher_suite) {
262                 Ok(cipher_suite) => cipher_suite,
263                 Err(_) => {
264                     debug!(
265                         "[handshake:{}] use cipher suite: {}",
266                         srv_cli_str(state.is_client),
267                         h.cipher_suite
268                     );
269 
270                     return Err((
271                         Some(Alert {
272                             alert_level: AlertLevel::Fatal,
273                             alert_description: AlertDescription::InsufficientSecurity,
274                         }),
275                         Some(Error::ErrInvalidCipherSuite),
276                     ));
277                 }
278             };
279 
280             trace!(
281                 "[handshake:{}] use cipher suite: {}",
282                 srv_cli_str(state.is_client),
283                 cipher_suite.to_string()
284             );
285             {
286                 let mut cs = state.cipher_suite.lock().await;
287                 *cs = Some(cipher_suite);
288             }
289             state.remote_random = h.random.clone();
290         }
291 
292         if let Some(message) = msgs.get(&HandshakeType::Certificate) {
293             let h = match message {
294                 HandshakeMessage::Certificate(h) => h,
295                 _ => {
296                     return Err((
297                         Some(Alert {
298                             alert_level: AlertLevel::Fatal,
299                             alert_description: AlertDescription::InternalError,
300                         }),
301                         None,
302                     ))
303                 }
304             };
305             state.peer_certificates = h.certificate.clone();
306         }
307 
308         if let Some(message) = msgs.get(&HandshakeType::ServerKeyExchange) {
309             let h = match message {
310                 HandshakeMessage::ServerKeyExchange(h) => h,
311                 _ => {
312                     return Err((
313                         Some(Alert {
314                             alert_level: AlertLevel::Fatal,
315                             alert_description: AlertDescription::InternalError,
316                         }),
317                         None,
318                     ))
319                 }
320             };
321 
322             if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) {
323                 return Err((alert, err));
324             }
325         }
326 
327         if let Some(message) = msgs.get(&HandshakeType::CertificateRequest) {
328             match message {
329                 HandshakeMessage::CertificateRequest(_) => {}
330                 _ => {
331                     return Err((
332                         Some(Alert {
333                             alert_level: AlertLevel::Fatal,
334                             alert_description: AlertDescription::InternalError,
335                         }),
336                         None,
337                     ))
338                 }
339             };
340             state.remote_requested_certificate = true;
341         }
342 
343         Ok(Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>)
344     }
345 
generate( &self, state: &mut State, _cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>346     async fn generate(
347         &self,
348         state: &mut State,
349         _cache: &HandshakeCache,
350         cfg: &HandshakeConfig,
351     ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
352         let mut extensions = vec![
353             Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms {
354                 signature_hash_algorithms: cfg.local_signature_schemes.clone(),
355             }),
356             Extension::RenegotiationInfo(ExtensionRenegotiationInfo {
357                 renegotiated_connection: 0,
358             }),
359         ];
360 
361         if cfg.local_psk_callback.is_none() {
362             extensions.extend_from_slice(&[
363                 Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves {
364                     elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384],
365                 }),
366                 Extension::SupportedPointFormats(ExtensionSupportedPointFormats {
367                     point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED],
368                 }),
369             ]);
370         }
371 
372         if !cfg.local_srtp_protection_profiles.is_empty() {
373             extensions.push(Extension::UseSrtp(ExtensionUseSrtp {
374                 protection_profiles: cfg.local_srtp_protection_profiles.clone(),
375             }));
376         }
377 
378         if cfg.extended_master_secret == ExtendedMasterSecretType::Request
379             || cfg.extended_master_secret == ExtendedMasterSecretType::Require
380         {
381             extensions.push(Extension::UseExtendedMasterSecret(
382                 ExtensionUseExtendedMasterSecret { supported: true },
383             ));
384         }
385 
386         if !cfg.server_name.is_empty() {
387             extensions.push(Extension::ServerName(ExtensionServerName {
388                 server_name: cfg.server_name.clone(),
389             }));
390         }
391 
392         Ok(vec![Packet {
393             record: RecordLayer::new(
394                 PROTOCOL_VERSION1_2,
395                 0,
396                 Content::Handshake(Handshake::new(HandshakeMessage::ClientHello(
397                     HandshakeMessageClientHello {
398                         version: PROTOCOL_VERSION1_2,
399                         random: state.local_random.clone(),
400                         cookie: state.cookie.clone(),
401 
402                         cipher_suites: cfg.local_cipher_suites.clone(),
403                         compression_methods: default_compression_methods(),
404                         extensions,
405                     },
406                 ))),
407             ),
408             should_encrypt: false,
409             reset_local_sequence_number: false,
410         }])
411     }
412 }
413 
handle_server_key_exchange( state: &mut State, cfg: &HandshakeConfig, h: &HandshakeMessageServerKeyExchange, ) -> Result<(), (Option<Alert>, Option<Error>)>414 pub(crate) fn handle_server_key_exchange(
415     state: &mut State,
416     cfg: &HandshakeConfig,
417     h: &HandshakeMessageServerKeyExchange,
418 ) -> Result<(), (Option<Alert>, Option<Error>)> {
419     if let Some(local_psk_callback) = &cfg.local_psk_callback {
420         let psk = match local_psk_callback(&h.identity_hint) {
421             Ok(psk) => psk,
422             Err(err) => {
423                 return Err((
424                     Some(Alert {
425                         alert_level: AlertLevel::Fatal,
426                         alert_description: AlertDescription::InternalError,
427                     }),
428                     Some(err),
429                 ))
430             }
431         };
432 
433         state.identity_hint = h.identity_hint.clone();
434         state.pre_master_secret = prf_psk_pre_master_secret(&psk);
435     } else {
436         let local_keypair = match h.named_curve.generate_keypair() {
437             Ok(local_keypair) => local_keypair,
438             Err(err) => {
439                 return Err((
440                     Some(Alert {
441                         alert_level: AlertLevel::Fatal,
442                         alert_description: AlertDescription::InternalError,
443                     }),
444                     Some(err),
445                 ))
446             }
447         };
448 
449         state.pre_master_secret = match prf_pre_master_secret(
450             &h.public_key,
451             &local_keypair.private_key,
452             local_keypair.curve,
453         ) {
454             Ok(pre_master_secret) => pre_master_secret,
455             Err(err) => {
456                 return Err((
457                     Some(Alert {
458                         alert_level: AlertLevel::Fatal,
459                         alert_description: AlertDescription::InternalError,
460                     }),
461                     Some(err),
462                 ))
463             }
464         };
465 
466         state.local_keypair = Some(local_keypair);
467     }
468 
469     Ok(())
470 }
471