1 use super::flight3::*; 2 use super::*; 3 use crate::compression_methods::*; 4 use crate::config::*; 5 use crate::conn::*; 6 use crate::content::*; 7 use crate::curve::named_curve::*; 8 use crate::error::Error; 9 use crate::extension::extension_server_name::*; 10 use crate::extension::extension_supported_elliptic_curves::*; 11 use crate::extension::extension_supported_point_formats::*; 12 use crate::extension::extension_supported_signature_algorithms::*; 13 use crate::extension::extension_use_extended_master_secret::*; 14 use crate::extension::extension_use_srtp::*; 15 use crate::extension::*; 16 use crate::handshake::handshake_message_client_hello::*; 17 use crate::handshake::*; 18 use crate::record_layer::record_layer_header::*; 19 use crate::record_layer::*; 20 21 use crate::extension::renegotiation_info::ExtensionRenegotiationInfo; 22 use async_trait::async_trait; 23 use std::fmt; 24 use std::sync::atomic::Ordering; 25 26 #[derive(Debug, PartialEq)] 27 pub(crate) struct Flight1; 28 29 impl fmt::Display for Flight1 { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 31 write!(f, "Flight 1") 32 } 33 } 34 35 #[async_trait] 36 impl Flight for Flight1 { parse( &self, tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>37 async fn parse( 38 &self, 39 tx: &mut mpsc::Sender<mpsc::Sender<()>>, 40 state: &mut State, 41 cache: &HandshakeCache, 42 cfg: &HandshakeConfig, 43 ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> { 44 // HelloVerifyRequest can be skipped by the server, 45 // so allow ServerHello during flight1 also 46 let (seq, msgs) = match cache 47 .full_pull_map( 48 state.handshake_recv_sequence, 49 &[ 50 HandshakeCachePullRule { 51 typ: HandshakeType::HelloVerifyRequest, 52 epoch: cfg.initial_epoch, 53 is_client: false, 54 optional: true, 55 }, 56 HandshakeCachePullRule { 57 typ: HandshakeType::ServerHello, 58 epoch: cfg.initial_epoch, 59 is_client: false, 60 optional: true, 61 }, 62 ], 63 ) 64 .await 65 { 66 // No valid message received. Keep reading 67 Ok((seq, msgs)) => (seq, msgs), 68 Err(_) => return Err((None, None)), 69 }; 70 71 if msgs.contains_key(&HandshakeType::ServerHello) { 72 // Flight1 and flight2 were skipped. 73 // Parse as flight3. 74 let flight3 = Flight3 {}; 75 return flight3.parse(tx, state, cache, cfg).await; 76 } 77 78 if let Some(message) = msgs.get(&HandshakeType::HelloVerifyRequest) { 79 // DTLS 1.2 clients must not assume that the server will use the protocol version 80 // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 81 let h = match message { 82 HandshakeMessage::HelloVerifyRequest(h) => h, 83 _ => { 84 return Err(( 85 Some(Alert { 86 alert_level: AlertLevel::Fatal, 87 alert_description: AlertDescription::InternalError, 88 }), 89 None, 90 )) 91 } 92 }; 93 94 if h.version != PROTOCOL_VERSION1_0 && h.version != PROTOCOL_VERSION1_2 { 95 return Err(( 96 Some(Alert { 97 alert_level: AlertLevel::Fatal, 98 alert_description: AlertDescription::ProtocolVersion, 99 }), 100 Some(Error::ErrUnsupportedProtocolVersion), 101 )); 102 } 103 104 state.cookie = h.cookie.clone(); 105 state.handshake_recv_sequence = seq; 106 Ok(Box::new(Flight3 {})) 107 } else { 108 Err(( 109 Some(Alert { 110 alert_level: AlertLevel::Fatal, 111 alert_description: AlertDescription::InternalError, 112 }), 113 None, 114 )) 115 } 116 } 117 generate( &self, state: &mut State, _cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>118 async fn generate( 119 &self, 120 state: &mut State, 121 _cache: &HandshakeCache, 122 cfg: &HandshakeConfig, 123 ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> { 124 let zero_epoch = 0; 125 state.local_epoch.store(zero_epoch, Ordering::SeqCst); 126 state.remote_epoch.store(zero_epoch, Ordering::SeqCst); 127 128 state.named_curve = DEFAULT_NAMED_CURVE; 129 state.cookie = vec![]; 130 state.local_random.populate(); 131 132 let mut extensions = vec![ 133 Extension::SupportedSignatureAlgorithms(ExtensionSupportedSignatureAlgorithms { 134 signature_hash_algorithms: cfg.local_signature_schemes.clone(), 135 }), 136 Extension::RenegotiationInfo(ExtensionRenegotiationInfo { 137 renegotiated_connection: 0, 138 }), 139 ]; 140 141 if cfg.local_psk_callback.is_none() { 142 extensions.extend_from_slice(&[ 143 Extension::SupportedEllipticCurves(ExtensionSupportedEllipticCurves { 144 elliptic_curves: vec![NamedCurve::P256, NamedCurve::X25519, NamedCurve::P384], 145 }), 146 Extension::SupportedPointFormats(ExtensionSupportedPointFormats { 147 point_formats: vec![ELLIPTIC_CURVE_POINT_FORMAT_UNCOMPRESSED], 148 }), 149 ]); 150 } 151 152 if !cfg.local_srtp_protection_profiles.is_empty() { 153 extensions.push(Extension::UseSrtp(ExtensionUseSrtp { 154 protection_profiles: cfg.local_srtp_protection_profiles.clone(), 155 })); 156 } 157 158 if cfg.extended_master_secret == ExtendedMasterSecretType::Request 159 || cfg.extended_master_secret == ExtendedMasterSecretType::Require 160 { 161 extensions.push(Extension::UseExtendedMasterSecret( 162 ExtensionUseExtendedMasterSecret { supported: true }, 163 )); 164 } 165 166 if !cfg.server_name.is_empty() { 167 extensions.push(Extension::ServerName(ExtensionServerName { 168 server_name: cfg.server_name.clone(), 169 })); 170 } 171 172 Ok(vec![Packet { 173 record: RecordLayer::new( 174 PROTOCOL_VERSION1_2, 175 0, 176 Content::Handshake(Handshake::new(HandshakeMessage::ClientHello( 177 HandshakeMessageClientHello { 178 version: PROTOCOL_VERSION1_2, 179 random: state.local_random.clone(), 180 cookie: state.cookie.clone(), 181 182 cipher_suites: cfg.local_cipher_suites.clone(), 183 compression_methods: default_compression_methods(), 184 extensions, 185 }, 186 ))), 187 ), 188 should_encrypt: false, 189 reset_local_sequence_number: false, 190 }]) 191 } 192 } 193