xref: /webrtc/dtls/src/flight/flight1.rs (revision ffe74184)
1 use super::flight3::*;
2 use super::*;
3 use crate::compression_methods::*;
4 use crate::config::*;
5 use crate::conn::*;
6 use crate::content::*;
7 use crate::curve::named_curve::*;
8 use crate::error::Error;
9 use crate::extension::extension_server_name::*;
10 use crate::extension::extension_supported_elliptic_curves::*;
11 use crate::extension::extension_supported_point_formats::*;
12 use crate::extension::extension_supported_signature_algorithms::*;
13 use crate::extension::extension_use_extended_master_secret::*;
14 use crate::extension::extension_use_srtp::*;
15 use crate::extension::*;
16 use crate::handshake::handshake_message_client_hello::*;
17 use crate::handshake::*;
18 use crate::record_layer::record_layer_header::*;
19 use crate::record_layer::*;
20 
21 use crate::extension::renegotiation_info::ExtensionRenegotiationInfo;
22 use async_trait::async_trait;
23 use std::fmt;
24 use std::sync::atomic::Ordering;
25 
/// Client handshake flight 1: produces the initial ClientHello and accepts the
/// server's HelloVerifyRequest (or a ServerHello when that step is skipped).
#[derive(PartialEq, Debug)]
pub(crate) struct Flight1;
28 
29 impl fmt::Display for Flight1 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result30     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31         write!(f, "Flight 1")
32     }
33 }
34 
35 #[async_trait]
36 impl Flight for Flight1 {
parse( &self, tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>37     async fn parse(
38         &self,
39         tx: &mut mpsc::Sender<mpsc::Sender<()>>,
40         state: &mut State,
41         cache: &HandshakeCache,
42         cfg: &HandshakeConfig,
43     ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
44         // HelloVerifyRequest can be skipped by the server,
45         // so allow ServerHello during flight1 also
46         let (seq, msgs) = match cache
47             .full_pull_map(
48                 state.handshake_recv_sequence,
49                 &[
50                     HandshakeCachePullRule {
51                         typ: HandshakeType::HelloVerifyRequest,
52                         epoch: cfg.initial_epoch,
53                         is_client: false,
54                         optional: true,
55                     },
56                     HandshakeCachePullRule {
57                         typ: HandshakeType::ServerHello,
58                         epoch: cfg.initial_epoch,
59                         is_client: false,
60                         optional: true,
61                     },
62                 ],
63             )
64             .await
65         {
66             // No valid message received. Keep reading
67             Ok((seq, msgs)) => (seq, msgs),
68             Err(_) => return Err((None, None)),
69         };
70 
71         if msgs.contains_key(&HandshakeType::ServerHello) {
72             // Flight1 and flight2 were skipped.
73             // Parse as flight3.
74             let flight3 = Flight3 {};
75             return flight3.parse(tx, state, cache, cfg).await;
76         }
77 
78         if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) {
79             // DTLS 1.2 clients must not assume that the server will use the protocol version
80             // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
81             let h = match message {
82                 HandshakeMessage::HelloVerifyRequest(h) => h,
83                 _ => {
84                     return Err((
85                         Some(Alert {
86                             alert_level: AlertLevel::Fatal,
87                             alert_description: AlertDescription::InternalError,
88                         }),
89                         None,
90                     ))
91                 }
92             };
93 
94             if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 {
95                 return Err((
96                     Some(Alert {
97                         alert_level: AlertLevel::Fatal,
98                         alert_description: AlertDescription::ProtocolVersion,
99                     }),
100                     Some(Error::ErrUnsupportedProtocolVersion),
101                 ));
102             }
103 
104             state.cookie = h.cookie.clone();
105             state.handshake_recv_sequence = seq;
106             Ok(Box::new(Flight3 {}))
107         } else {
108             Err((
109                 Some(Alert {
110                     alert_level: AlertLevel::Fatal,
111                     alert_description: AlertDescription::InternalError,
112                 }),
113                 None,
114             ))
115         }
116     }
117 
generate( &self, state: &mut State, _cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>118     async fn generate(
119         &self,
120         state: &mut State,
121         _cache: &HandshakeCache,
122         cfg: &HandshakeConfig,
123     ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
124         let zero_epoch = 0;
125         state.local_epoch.store(zero_epoch, Ordering::SeqCst);
126         state.remote_epoch.store(zero_epoch, Ordering::SeqCst);
127 
128         state.named_curve = DEFAULT_NAMED_CURVE;
129         state.cookie = vec![];
130         state.local_random.populate();
131 
132         let mut extensions = vec![
133             Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms {
134                 signature_hash_algorithms: cfg.local_signature_schemes.clone(),
135             }),
136             Extension::RenegotiationInfo(ExtensionRenegotiationInfo {
137                 renegotiated_connection: 0,
138             }),
139         ];
140 
141         if cfg.local_psk_callback.is_none() {
142             extensions.extend_from_slice(&[
143                 Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves {
144                     elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384],
145                 }),
146                 Extension::SupportedPointFormats(ExtensionSupportedPointFormats {
147                     point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED],
148                 }),
149             ]);
150         }
151 
152         if !cfg.local_srtp_protection_profiles.is_empty() {
153             extensions.push(Extension::UseSrtp(ExtensionUseSrtp {
154                 protection_profiles: cfg.local_srtp_protection_profiles.clone(),
155             }));
156         }
157 
158         if cfg.extended_master_secret == ExtendedMasterSecretType::Request
159             || cfg.extended_master_secret == ExtendedMasterSecretType::Require
160         {
161             extensions.push(Extension::UseExtendedMasterSecret(
162                 ExtensionUseExtendedMasterSecret { supported: true },
163             ));
164         }
165 
166         if !cfg.server_name.is_empty() {
167             extensions.push(Extension::ServerName(ExtensionServerName {
168                 server_name: cfg.server_name.clone(),
169             }));
170         }
171 
172         Ok(vec![Packet {
173             record: RecordLayer::new(
174                 PROTOCOL_VERSION1_2,
175                 0,
176                 Content::Handshake(Handshake::new(HandshakeMessage::ClientHello(
177                     HandshakeMessageClientHello {
178                         version: PROTOCOL_VERSION1_2,
179                         random: state.local_random.clone(),
180                         cookie: state.cookie.clone(),
181 
182                         cipher_suites: cfg.local_cipher_suites.clone(),
183                         compression_methods: default_compression_methods(),
184                         extensions,
185                     },
186                 ))),
187             ),
188             should_encrypt: false,
189             reset_local_sequence_number: false,
190         }])
191     }
192 }
193