1 use super::flight2::*; 2 use super::*; 3 use crate::config::*; 4 use crate::conn::*; 5 use crate::error::Error; 6 use crate::extension::*; 7 use crate::handshake::*; 8 use crate::record_layer::record_layer_header::*; 9 use crate::*; 10 11 use async_trait::async_trait; 12 use rand::Rng; 13 use std::fmt; 14 use std::sync::atomic::Ordering; 15 16 #[derive(Debug, PartialEq)] 17 pub(crate) struct Flight0; 18 19 impl fmt::Display for Flight0 { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 21 write!(f, "Flight 0") 22 } 23 } 24 25 #[async_trait] 26 impl Flight for Flight0 { parse( &self, _tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>27 async fn parse( 28 &self, 29 _tx: &mut mpsc::Sender<mpsc::Sender<()>>, 30 state: &mut State, 31 cache: &HandshakeCache, 32 cfg: &HandshakeConfig, 33 ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> { 34 let (seq, msgs) = match cache 35 .full_pull_map( 36 0, 37 &[HandshakeCachePullRule { 38 typ: HandshakeType::ClientHello, 39 epoch: cfg.initial_epoch, 40 is_client: true, 41 optional: false, 42 }], 43 ) 44 .await 45 { 46 Ok((seq, msgs)) => (seq, msgs), 47 Err(_) => return Err((None, None)), 48 }; 49 50 state.handshake_recv_sequence = seq; 51 52 if let Some(message) = msgs.get(&HandshakeType::ClientHello) { 53 // Validate type 54 let client_hello = match message { 55 HandshakeMessage::ClientHello(client_hello) => client_hello, 56 _ => { 57 return Err(( 58 Some(Alert { 59 alert_level: AlertLevel::Fatal, 60 alert_description: AlertDescription::InternalError, 61 }), 62 None, 63 )) 64 } 65 }; 66 67 if client_hello.version != PROTOCOL_VERSION1_2 { 68 return Err(( 69 Some(Alert { 70 alert_level: AlertLevel::Fatal, 71 alert_description: AlertDescription::ProtocolVersion, 72 }), 73 Some(Error::ErrUnsupportedProtocolVersion), 74 )); 75 } 76 77 state.remote_random = client_hello.random.clone(); 78 79 if let Ok(id) = 80 find_matching_cipher_suite(&client_hello.cipher_suites, &cfg.local_cipher_suites) 81 { 82 if let Ok(cipher_suite) = cipher_suite_for_id(id) { 83 log::debug!( 84 "[handshake:{}] use cipher suite: {}", 85 srv_cli_str(state.is_client), 86 cipher_suite.to_string() 87 ); 88 let mut cs = state.cipher_suite.lock().await; 89 *cs = Some(cipher_suite); 90 } 91 } else { 92 return Err(( 93 Some(Alert { 94 alert_level: AlertLevel::Fatal, 95 alert_description: AlertDescription::InsufficientSecurity, 96 }), 97 Some(Error::ErrCipherSuiteNoIntersection), 98 )); 99 } 100 101 for extension in &client_hello.extensions { 102 match extension { 103 Extension::SupportedEllipticCurves(e) => { 104 if e.elliptic_curves.is_empty() { 105 return Err(( 106 Some(Alert { 107 alert_level: AlertLevel::Fatal, 108 alert_description: AlertDescription::InsufficientSecurity, 109 }), 110 Some(Error::ErrNoSupportedEllipticCurves), 111 )); 112 } 113 state.named_curve = e.elliptic_curves[0]; 114 } 115 Extension::UseSrtp(e) => { 116 if let Ok(profile) = find_matching_srtp_profile( 117 &e.protection_profiles, 118 &cfg.local_srtp_protection_profiles, 119 ) { 120 state.srtp_protection_profile = profile; 121 } else { 122 return Err(( 123 Some(Alert { 124 alert_level: AlertLevel::Fatal, 125 alert_description: AlertDescription::InsufficientSecurity, 126 }), 127 Some(Error::ErrServerNoMatchingSrtpProfile), 128 )); 129 } 130 } 131 Extension::UseExtendedMasterSecret(_) => { 132 if cfg.extended_master_secret != ExtendedMasterSecretType::Disable { 133 state.extended_master_secret = true; 134 } 135 } 136 Extension::ServerName(e) => { 137 state.server_name = e.server_name.clone(); // remote server name 138 } 139 _ => {} 140 } 141 } 142 143 if cfg.extended_master_secret == ExtendedMasterSecretType::Require 144 && !state.extended_master_secret 145 { 146 return Err(( 147 Some(Alert { 148 alert_level: AlertLevel::Fatal, 149 alert_description: AlertDescription::InsufficientSecurity, 150 }), 151 Some(Error::ErrServerRequiredButNoClientEms), 152 )); 153 } 154 155 if state.local_keypair.is_none() { 156 state.local_keypair = match state.named_curve.generate_keypair() { 157 Ok(local_keypar) => Some(local_keypar), 158 Err(err) => { 159 return Err(( 160 Some(Alert { 161 alert_level: AlertLevel::Fatal, 162 alert_description: AlertDescription::IllegalParameter, 163 }), 164 Some(err), 165 )) 166 } 167 }; 168 } 169 170 Ok(Box::new(Flight2 {})) 171 } else { 172 Err(( 173 Some(Alert { 174 alert_level: AlertLevel::Fatal, 175 alert_description: AlertDescription::InternalError, 176 }), 177 None, 178 )) 179 } 180 } 181 generate( &self, state: &mut State, _cache: &HandshakeCache, _cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>182 async fn generate( 183 &self, 184 state: &mut State, 185 _cache: &HandshakeCache, 186 _cfg: &HandshakeConfig, 187 ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> { 188 // Initialize 189 state.cookie = vec![0; COOKIE_LENGTH]; 190 rand::thread_rng().fill(state.cookie.as_mut_slice()); 191 192 //TODO: figure out difference between golang's atom store and rust atom store 193 let zero_epoch = 0; 194 state.local_epoch.store(zero_epoch, Ordering::SeqCst); 195 state.remote_epoch.store(zero_epoch, Ordering::SeqCst); 196 197 state.named_curve = DEFAULT_NAMED_CURVE; 198 state.local_random.populate(); 199 200 Ok(vec![]) 201 } 202 } 203