xref: /webrtc/dtls/src/flight/flight0.rs (revision ffe74184)
1 use super::flight2::*;
2 use super::*;
3 use crate::config::*;
4 use crate::conn::*;
5 use crate::error::Error;
6 use crate::extension::*;
7 use crate::handshake::*;
8 use crate::record_layer::record_layer_header::*;
9 use crate::*;
10 
11 use async_trait::async_trait;
12 use rand::Rng;
13 use std::fmt;
14 use std::sync::atomic::Ordering;
15 
/// Marker type for handshake flight 0. It holds no data; all handshake
/// progress lives in the `State` passed to its `Flight` methods.
#[derive(PartialEq, Debug)]
pub(crate) struct Flight0;
18 
19 impl fmt::Display for Flight0 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result20     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21         write!(f, "Flight 0")
22     }
23 }
24 
25 #[async_trait]
26 impl Flight for Flight0 {
parse( &self, _tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>27     async fn parse(
28         &self,
29         _tx: &mut mpsc::Sender<mpsc::Sender<()>>,
30         state: &mut State,
31         cache: &HandshakeCache,
32         cfg: &HandshakeConfig,
33     ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
34         let (seq, msgs) = match cache
35             .full_pull_map(
36                 0,
37                 &[HandshakeCachePullRule {
38                     typ: HandshakeType::ClientHello,
39                     epoch: cfg.initial_epoch,
40                     is_client: true,
41                     optional: false,
42                 }],
43             )
44             .await
45         {
46             Ok((seq, msgs)) => (seq, msgs),
47             Err(_) => return Err((None, None)),
48         };
49 
50         state.handshake_recv_sequence = seq;
51 
52         if let Some(message) = msgs.get(&HandshakeType::ClientHello) {
53             // Validate type
54             let client_hello = match message {
55                 HandshakeMessage::ClientHello(client_hello) => client_hello,
56                 _ => {
57                     return Err((
58                         Some(Alert {
59                             alert_level: AlertLevel::Fatal,
60                             alert_description: AlertDescription::InternalError,
61                         }),
62                         None,
63                     ))
64                 }
65             };
66 
67             if client_hello.version != PROTOCOL_VERSION1_2 {
68                 return Err((
69                     Some(Alert {
70                         alert_level: AlertLevel::Fatal,
71                         alert_description: AlertDescription::ProtocolVersion,
72                     }),
73                     Some(Error::ErrUnsupportedProtocolVersion),
74                 ));
75             }
76 
77             state.remote_random = client_hello.random.clone();
78 
79             if let Ok(id) =
80                 find_matching_cipher_suite(&client_hello.cipher_suites, &cfg.local_cipher_suites)
81             {
82                 if let Ok(cipher_suite) = cipher_suite_for_id(id) {
83                     log::debug!(
84                         "[handshake:{}] use cipher suite: {}",
85                         srv_cli_str(state.is_client),
86                         cipher_suite.to_string()
87                     );
88                     let mut cs = state.cipher_suite.lock().await;
89                     *cs = Some(cipher_suite);
90                 }
91             } else {
92                 return Err((
93                     Some(Alert {
94                         alert_level: AlertLevel::Fatal,
95                         alert_description: AlertDescription::InsufficientSecurity,
96                     }),
97                     Some(Error::ErrCipherSuiteNoIntersection),
98                 ));
99             }
100 
101             for extension in &client_hello.extensions {
102                 match extension {
103                     Extension::SupportedEllipticCurves(e) => {
104                         if e.elliptic_curves.is_empty() {
105                             return Err((
106                                 Some(Alert {
107                                     alert_level: AlertLevel::Fatal,
108                                     alert_description: AlertDescription::InsufficientSecurity,
109                                 }),
110                                 Some(Error::ErrNoSupportedEllipticCurves),
111                             ));
112                         }
113                         state.named_curve = e.elliptic_curves[0];
114                     }
115                     Extension::UseSrtp(e) => {
116                         if let Ok(profile) = find_matching_srtp_profile(
117                             &e.protection_profiles,
118                             &cfg.local_srtp_protection_profiles,
119                         ) {
120                             state.srtp_protection_profile = profile;
121                         } else {
122                             return Err((
123                                 Some(Alert {
124                                     alert_level: AlertLevel::Fatal,
125                                     alert_description: AlertDescription::InsufficientSecurity,
126                                 }),
127                                 Some(Error::ErrServerNoMatchingSrtpProfile),
128                             ));
129                         }
130                     }
131                     Extension::UseExtendedMasterSecret(_) => {
132                         if cfg.extended_master_secret != ExtendedMasterSecretType::Disable {
133                             state.extended_master_secret = true;
134                         }
135                     }
136                     Extension::ServerName(e) => {
137                         state.server_name = e.server_name.clone(); // remote server name
138                     }
139                     _ => {}
140                 }
141             }
142 
143             if cfg.extended_master_secret == ExtendedMasterSecretType::Require
144                 && !state.extended_master_secret
145             {
146                 return Err((
147                     Some(Alert {
148                         alert_level: AlertLevel::Fatal,
149                         alert_description: AlertDescription::InsufficientSecurity,
150                     }),
151                     Some(Error::ErrServerRequiredButNoClientEms),
152                 ));
153             }
154 
155             if state.local_keypair.is_none() {
156                 state.local_keypair = match state.named_curve.generate_keypair() {
157                     Ok(local_keypar) => Some(local_keypar),
158                     Err(err) => {
159                         return Err((
160                             Some(Alert {
161                                 alert_level: AlertLevel::Fatal,
162                                 alert_description: AlertDescription::IllegalParameter,
163                             }),
164                             Some(err),
165                         ))
166                     }
167                 };
168             }
169 
170             Ok(Box::new(Flight2 {}))
171         } else {
172             Err((
173                 Some(Alert {
174                     alert_level: AlertLevel::Fatal,
175                     alert_description: AlertDescription::InternalError,
176                 }),
177                 None,
178             ))
179         }
180     }
181 
generate( &self, state: &mut State, _cache: &HandshakeCache, _cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>182     async fn generate(
183         &self,
184         state: &mut State,
185         _cache: &HandshakeCache,
186         _cfg: &HandshakeConfig,
187     ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
188         // Initialize
189         state.cookie = vec![0; COOKIE_LENGTH];
190         rand::thread_rng().fill(state.cookie.as_mut_slice());
191 
192         //TODO: figure out difference between golang's atom store and rust atom store
193         let zero_epoch = 0;
194         state.local_epoch.store(zero_epoch, Ordering::SeqCst);
195         state.remote_epoch.store(zero_epoch, Ordering::SeqCst);
196 
197         state.named_curve = DEFAULT_NAMED_CURVE;
198         state.local_random.populate();
199 
200         Ok(vec![])
201     }
202 }
203