1 use super::flight5::*;
2 use super::*;
3 use crate::compression_methods::*;
4 use crate::config::*;
5 use crate::content::*;
6 use crate::curve::named_curve::*;
7 use crate::error::Error;
8 use crate::extension::extension_server_name::*;
9 use crate::extension::extension_supported_elliptic_curves::*;
10 use crate::extension::extension_supported_point_formats::*;
11 use crate::extension::extension_supported_signature_algorithms::*;
12 use crate::extension::extension_use_extended_master_secret::*;
13 use crate::extension::extension_use_srtp::*;
14 use crate::extension::*;
15 use crate::handshake::handshake_message_client_hello::*;
16 use crate::handshake::handshake_message_server_key_exchange::*;
17 use crate::handshake::*;
18 use crate::record_layer::record_layer_header::*;
19 use crate::record_layer::*;
20
21 use crate::cipher_suite::cipher_suite_for_id;
22 use crate::prf::{prf_pre_master_secret, prf_psk_pre_master_secret};
23 use crate::{find_matching_cipher_suite, find_matching_srtp_profile};
24
25 use crate::extension::renegotiation_info::ExtensionRenegotiationInfo;
26 use async_trait::async_trait;
27 use log::*;
28 use std::fmt;
29
/// The client's third handshake flight.
///
/// `generate` produces the ClientHello (echoing any cookie stored in the state), and
/// `parse` consumes the server's answer: either a HelloVerifyRequest or the
/// ServerHello .. ServerHelloDone flight.
#[derive(Debug, PartialEq)]
pub(crate) struct Flight3;

impl fmt::Display for Flight3 {
    /// Human-readable flight name, used in handshake logging.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Flight 3")
    }
}
38
39 #[async_trait]
40 impl Flight for Flight3 {
parse( &self, _tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>41 async fn parse(
42 &self,
43 _tx: &mut mpsc::Sender<mpsc::Sender<()>>,
44 state: &mut State,
45 cache: &HandshakeCache,
46 cfg: &HandshakeConfig,
47 ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
48 // Clients may receive multiple HelloVerifyRequest messages with different cookies.
49 // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
50 // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
51 if let Ok((seq, msgs)) = cache
52 .full_pull_map(
53 state.handshake_recv_sequence,
54 &[HandshakeCachePullRule {
55 typ: HandshakeType::HelloVerifyRequest,
56 epoch: cfg.initial_epoch,
57 is_client: false,
58 optional: true,
59 }],
60 )
61 .await
62 {
63 if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) {
64 // DTLS 1.2 clients must not assume that the server will use the protocol version
65 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
66 let h = match message {
67 HandshakeMessage::HelloVerifyRequest(h) => h,
68 _ => {
69 return Err((
70 Some(Alert {
71 alert_level: AlertLevel::Fatal,
72 alert_description: AlertDescription::InternalError,
73 }),
74 None,
75 ))
76 }
77 };
78
79 // DTLS 1.2 clients must not assume that the server will use the protocol version
80 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
81 if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 {
82 return Err((
83 Some(Alert {
84 alert_level: AlertLevel::Fatal,
85 alert_description: AlertDescription::ProtocolVersion,
86 }),
87 Some(Error::ErrUnsupportedProtocolVersion),
88 ));
89 }
90
91 state.cookie = h.cookie.clone();
92 state.handshake_recv_sequence = seq;
93 return Ok(Box::new(Flight3 {}) as Box<dyn Flight + Send + Sync>);
94 }
95 }
96
97 let result = if cfg.local_psk_callback.is_some() {
98 cache
99 .full_pull_map(
100 state.handshake_recv_sequence,
101 &[
102 HandshakeCachePullRule {
103 typ: HandshakeType::ServerHello,
104 epoch: cfg.initial_epoch,
105 is_client: false,
106 optional: false,
107 },
108 HandshakeCachePullRule {
109 typ: HandshakeType::ServerKeyExchange,
110 epoch: cfg.initial_epoch,
111 is_client: false,
112 optional: true,
113 },
114 HandshakeCachePullRule {
115 typ: HandshakeType::ServerHelloDone,
116 epoch: cfg.initial_epoch,
117 is_client: false,
118 optional: false,
119 },
120 ],
121 )
122 .await
123 } else {
124 cache
125 .full_pull_map(
126 state.handshake_recv_sequence,
127 &[
128 HandshakeCachePullRule {
129 typ: HandshakeType::ServerHello,
130 epoch: cfg.initial_epoch,
131 is_client: false,
132 optional: false,
133 },
134 HandshakeCachePullRule {
135 typ: HandshakeType::Certificate,
136 epoch: cfg.initial_epoch,
137 is_client: false,
138 optional: true,
139 },
140 HandshakeCachePullRule {
141 typ: HandshakeType::ServerKeyExchange,
142 epoch: cfg.initial_epoch,
143 is_client: false,
144 optional: false,
145 },
146 HandshakeCachePullRule {
147 typ: HandshakeType::CertificateRequest,
148 epoch: cfg.initial_epoch,
149 is_client: false,
150 optional: true,
151 },
152 HandshakeCachePullRule {
153 typ: HandshakeType::ServerHelloDone,
154 epoch: cfg.initial_epoch,
155 is_client: false,
156 optional: false,
157 },
158 ],
159 )
160 .await
161 };
162
163 let (seq, msgs) = match result {
164 Ok((seq, msgs)) => (seq, msgs),
165 Err(_) => return Err((None, None)),
166 };
167
168 state.handshake_recv_sequence = seq;
169
170 if let Some(message) = msgs.get(&HandshakeType::ServerHello) {
171 let h = match message {
172 HandshakeMessage::ServerHello(h) => h,
173 _ => {
174 return Err((
175 Some(Alert {
176 alert_level: AlertLevel::Fatal,
177 alert_description: AlertDescription::InternalError,
178 }),
179 None,
180 ))
181 }
182 };
183
184 if h.version != PROTOCOL_VERSION1_2 {
185 return Err((
186 Some(Alert {
187 alert_level: AlertLevel::Fatal,
188 alert_description: AlertDescription::ProtocolVersion,
189 }),
190 Some(Error::ErrUnsupportedProtocolVersion),
191 ));
192 }
193
194 for extension in &h.extensions {
195 match extension {
196 Extension::UseSrtp(e) => {
197 let profile = match find_matching_srtp_profile(
198 &e.protection_profiles,
199 &cfg.local_srtp_protection_profiles,
200 ) {
201 Ok(profile) => profile,
202 Err(_) => {
203 return Err((
204 Some(Alert {
205 alert_level: AlertLevel::Fatal,
206 alert_description: AlertDescription::IllegalParameter,
207 }),
208 Some(Error::ErrClientNoMatchingSrtpProfile),
209 ))
210 }
211 };
212 state.srtp_protection_profile = profile;
213 }
214 Extension::UseExtendedMasterSecret(_) => {
215 if cfg.extended_master_secret != ExtendedMasterSecretType::Disable {
216 state.extended_master_secret = true;
217 }
218 }
219 _ => {}
220 };
221 }
222
223 if cfg.extended_master_secret == ExtendedMasterSecretType::Require
224 && !state.extended_master_secret
225 {
226 return Err((
227 Some(Alert {
228 alert_level: AlertLevel::Fatal,
229 alert_description: AlertDescription::InsufficientSecurity,
230 }),
231 Some(Error::ErrClientRequiredButNoServerEms),
232 ));
233 }
234 if !cfg.local_srtp_protection_profiles.is_empty()
235 && state.srtp_protection_profile == SrtpProtectionProfile::Unsupported
236 {
237 return Err((
238 Some(Alert {
239 alert_level: AlertLevel::Fatal,
240 alert_description: AlertDescription::InsufficientSecurity,
241 }),
242 Some(Error::ErrRequestedButNoSrtpExtension),
243 ));
244 }
245 if find_matching_cipher_suite(&[h.cipher_suite], &cfg.local_cipher_suites).is_err() {
246 debug!(
247 "[handshake:{}] use cipher suite: {}",
248 srv_cli_str(state.is_client),
249 h.cipher_suite
250 );
251
252 return Err((
253 Some(Alert {
254 alert_level: AlertLevel::Fatal,
255 alert_description: AlertDescription::InsufficientSecurity,
256 }),
257 Some(Error::ErrCipherSuiteNoIntersection),
258 ));
259 }
260
261 let cipher_suite = match cipher_suite_for_id(h.cipher_suite) {
262 Ok(cipher_suite) => cipher_suite,
263 Err(_) => {
264 debug!(
265 "[handshake:{}] use cipher suite: {}",
266 srv_cli_str(state.is_client),
267 h.cipher_suite
268 );
269
270 return Err((
271 Some(Alert {
272 alert_level: AlertLevel::Fatal,
273 alert_description: AlertDescription::InsufficientSecurity,
274 }),
275 Some(Error::ErrInvalidCipherSuite),
276 ));
277 }
278 };
279
280 trace!(
281 "[handshake:{}] use cipher suite: {}",
282 srv_cli_str(state.is_client),
283 cipher_suite.to_string()
284 );
285 {
286 let mut cs = state.cipher_suite.lock().await;
287 *cs = Some(cipher_suite);
288 }
289 state.remote_random = h.random.clone();
290 }
291
292 if let Some(message) = msgs.get(&HandshakeType::Certificate) {
293 let h = match message {
294 HandshakeMessage::Certificate(h) => h,
295 _ => {
296 return Err((
297 Some(Alert {
298 alert_level: AlertLevel::Fatal,
299 alert_description: AlertDescription::InternalError,
300 }),
301 None,
302 ))
303 }
304 };
305 state.peer_certificates = h.certificate.clone();
306 }
307
308 if let Some(message) = msgs.get(&HandshakeType::ServerKeyExchange) {
309 let h = match message {
310 HandshakeMessage::ServerKeyExchange(h) => h,
311 _ => {
312 return Err((
313 Some(Alert {
314 alert_level: AlertLevel::Fatal,
315 alert_description: AlertDescription::InternalError,
316 }),
317 None,
318 ))
319 }
320 };
321
322 if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) {
323 return Err((alert, err));
324 }
325 }
326
327 if let Some(message) = msgs.get(&HandshakeType::CertificateRequest) {
328 match message {
329 HandshakeMessage::CertificateRequest(_) => {}
330 _ => {
331 return Err((
332 Some(Alert {
333 alert_level: AlertLevel::Fatal,
334 alert_description: AlertDescription::InternalError,
335 }),
336 None,
337 ))
338 }
339 };
340 state.remote_requested_certificate = true;
341 }
342
343 Ok(Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>)
344 }
345
generate( &self, state: &mut State, _cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>346 async fn generate(
347 &self,
348 state: &mut State,
349 _cache: &HandshakeCache,
350 cfg: &HandshakeConfig,
351 ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
352 let mut extensions = vec![
353 Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms {
354 signature_hash_algorithms: cfg.local_signature_schemes.clone(),
355 }),
356 Extension::RenegotiationInfo(ExtensionRenegotiationInfo {
357 renegotiated_connection: 0,
358 }),
359 ];
360
361 if cfg.local_psk_callback.is_none() {
362 extensions.extend_from_slice(&[
363 Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves {
364 elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384],
365 }),
366 Extension::SupportedPointFormats(ExtensionSupportedPointFormats {
367 point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED],
368 }),
369 ]);
370 }
371
372 if !cfg.local_srtp_protection_profiles.is_empty() {
373 extensions.push(Extension::UseSrtp(ExtensionUseSrtp {
374 protection_profiles: cfg.local_srtp_protection_profiles.clone(),
375 }));
376 }
377
378 if cfg.extended_master_secret == ExtendedMasterSecretType::Request
379 || cfg.extended_master_secret == ExtendedMasterSecretType::Require
380 {
381 extensions.push(Extension::UseExtendedMasterSecret(
382 ExtensionUseExtendedMasterSecret { supported: true },
383 ));
384 }
385
386 if !cfg.server_name.is_empty() {
387 extensions.push(Extension::ServerName(ExtensionServerName {
388 server_name: cfg.server_name.clone(),
389 }));
390 }
391
392 Ok(vec![Packet {
393 record: RecordLayer::new(
394 PROTOCOL_VERSION1_2,
395 0,
396 Content::Handshake(Handshake::new(HandshakeMessage::ClientHello(
397 HandshakeMessageClientHello {
398 version: PROTOCOL_VERSION1_2,
399 random: state.local_random.clone(),
400 cookie: state.cookie.clone(),
401
402 cipher_suites: cfg.local_cipher_suites.clone(),
403 compression_methods: default_compression_methods(),
404 extensions,
405 },
406 ))),
407 ),
408 should_encrypt: false,
409 reset_local_sequence_number: false,
410 }])
411 }
412 }
413
handle_server_key_exchange( state: &mut State, cfg: &HandshakeConfig, h: &HandshakeMessageServerKeyExchange, ) -> Result<(), (Option<Alert>, Option<Error>)>414 pub(crate) fn handle_server_key_exchange(
415 state: &mut State,
416 cfg: &HandshakeConfig,
417 h: &HandshakeMessageServerKeyExchange,
418 ) -> Result<(), (Option<Alert>, Option<Error>)> {
419 if let Some(local_psk_callback) = &cfg.local_psk_callback {
420 let psk = match local_psk_callback(&h.identity_hint) {
421 Ok(psk) => psk,
422 Err(err) => {
423 return Err((
424 Some(Alert {
425 alert_level: AlertLevel::Fatal,
426 alert_description: AlertDescription::InternalError,
427 }),
428 Some(err),
429 ))
430 }
431 };
432
433 state.identity_hint = h.identity_hint.clone();
434 state.pre_master_secret = prf_psk_pre_master_secret(&psk);
435 } else {
436 let local_keypair = match h.named_curve.generate_keypair() {
437 Ok(local_keypair) => local_keypair,
438 Err(err) => {
439 return Err((
440 Some(Alert {
441 alert_level: AlertLevel::Fatal,
442 alert_description: AlertDescription::InternalError,
443 }),
444 Some(err),
445 ))
446 }
447 };
448
449 state.pre_master_secret = match prf_pre_master_secret(
450 &h.public_key,
451 &local_keypair.private_key,
452 local_keypair.curve,
453 ) {
454 Ok(pre_master_secret) => pre_master_secret,
455 Err(err) => {
456 return Err((
457 Some(Alert {
458 alert_level: AlertLevel::Fatal,
459 alert_description: AlertDescription::InternalError,
460 }),
461 Some(err),
462 ))
463 }
464 };
465
466 state.local_keypair = Some(local_keypair);
467 }
468
469 Ok(())
470 }
471