xref: /webrtc/dtls/src/flight/flight2.rs (revision ffe74184)
1 use super::flight0::*;
2 use super::flight4::*;
3 use super::*;
4 use crate::content::*;
5 use crate::error::Error;
6 use crate::handshake::handshake_message_hello_verify_request::*;
7 use crate::handshake::*;
8 use crate::record_layer::record_layer_header::*;
9 
10 use async_trait::async_trait;
11 use std::fmt;
12 
/// Server-side DTLS handshake flight 2: the server's HelloVerifyRequest,
/// carrying the cookie the client must echo back in its next ClientHello.
#[derive(PartialEq, Debug)]
pub(crate) struct Flight2;
15 
16 impl fmt::Display for Flight2 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result17     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18         write!(f, "Flight 2")
19     }
20 }
21 
22 #[async_trait]
23 impl Flight for Flight2 {
has_retransmit(&self) -> bool24     fn has_retransmit(&self) -> bool {
25         false
26     }
27 
parse( &self, tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>28     async fn parse(
29         &self,
30         tx: &mut mpsc::Sender<mpsc::Sender<()>>,
31         state: &mut State,
32         cache: &HandshakeCache,
33         cfg: &HandshakeConfig,
34     ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
35         let (seq, msgs) = match cache
36             .full_pull_map(
37                 state.handshake_recv_sequence,
38                 &[HandshakeCachePullRule {
39                     typ: HandshakeType::ClientHello,
40                     epoch: cfg.initial_epoch,
41                     is_client: true,
42                     optional: false,
43                 }],
44             )
45             .await
46         {
47             // No valid message received. Keep reading
48             Ok((seq, msgs)) => (seq, msgs),
49 
50             // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped.
51             // Parse as flight 0 in this case.
52             Err(_) => return Flight0 {}.parse(tx, state, cache, cfg).await,
53         };
54 
55         state.handshake_recv_sequence = seq;
56 
57         if let Some(message) = msgs.get(&HandshakeType::ClientHello) {
58             // Validate type
59             let client_hello = match message {
60                 HandshakeMessage::ClientHello(client_hello) => client_hello,
61                 _ => {
62                     return Err((
63                         Some(Alert {
64                             alert_level: AlertLevel::Fatal,
65                             alert_description: AlertDescription::InternalError,
66                         }),
67                         None,
68                     ))
69                 }
70             };
71 
72             if client_hello.version != PROTOCOL_VERSION1_2 {
73                 return Err((
74                     Some(Alert {
75                         alert_level: AlertLevel::Fatal,
76                         alert_description: AlertDescription::ProtocolVersion,
77                     }),
78                     Some(Error::ErrUnsupportedProtocolVersion),
79                 ));
80             }
81 
82             if client_hello.cookie.is_empty() {
83                 return Err((None, None));
84             }
85 
86             if state.cookie != client_hello.cookie {
87                 return Err((
88                     Some(Alert {
89                         alert_level: AlertLevel::Fatal,
90                         alert_description: AlertDescription::AccessDenied,
91                     }),
92                     Some(Error::ErrCookieMismatch),
93                 ));
94             }
95 
96             Ok(Box::new(Flight4 {}))
97         } else {
98             Err((
99                 Some(Alert {
100                     alert_level: AlertLevel::Fatal,
101                     alert_description: AlertDescription::InternalError,
102                 }),
103                 None,
104             ))
105         }
106     }
107 
generate( &self, state: &mut State, _cache: &HandshakeCache, _cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>108     async fn generate(
109         &self,
110         state: &mut State,
111         _cache: &HandshakeCache,
112         _cfg: &HandshakeConfig,
113     ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
114         state.handshake_send_sequence = 0;
115         Ok(vec![Packet {
116             record: RecordLayer::new(
117                 PROTOCOL_VERSION1_2,
118                 0,
119                 Content::Handshake(Handshake::new(HandshakeMessage::HelloVerifyRequest(
120                     HandshakeMessageHelloVerifyRequest {
121                         version: PROTOCOL_VERSION1_2,
122                         cookie: state.cookie.clone(),
123                     },
124                 ))),
125             ),
126             should_encrypt: false,
127             reset_local_sequence_number: false,
128         }])
129     }
130 }
131