xref: /webrtc/dtls/src/flight/flight6.rs (revision ffe74184)
1 use super::*;
2 use crate::change_cipher_spec::*;
3 use crate::content::*;
4 use crate::handshake::handshake_message_finished::*;
5 use crate::handshake::*;
6 use crate::prf::*;
7 use crate::record_layer::record_layer_header::*;
8 
9 use async_trait::async_trait;
10 use std::fmt;
11 
/// Marker type for handshake flight 6.
///
/// It carries no data of its own; all behaviour lives in its trait
/// implementations.
#[derive(Debug, PartialEq)]
pub(crate) struct Flight6;

impl fmt::Display for Flight6 {
    /// Writes the fixed, human-readable flight name `"Flight 6"`.
    ///
    /// The text is emitted verbatim: width, fill and precision flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Flight 6")
    }
}
20 
21 #[async_trait]
22 impl Flight for Flight6 {
is_last_send_flight(&self) -> bool23     fn is_last_send_flight(&self) -> bool {
24         true
25     }
26 
parse( &self, _tx: &mut mpsc::Sender<mpsc::Sender<()>>, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)>27     async fn parse(
28         &self,
29         _tx: &mut mpsc::Sender<mpsc::Sender<()>>,
30         state: &mut State,
31         cache: &HandshakeCache,
32         cfg: &HandshakeConfig,
33     ) -> Result<Box<dyn Flight + Send + Sync>, (Option<Alert>, Option<Error>)> {
34         let (_, msgs) = match cache
35             .full_pull_map(
36                 state.handshake_recv_sequence - 1,
37                 &[HandshakeCachePullRule {
38                     typ: HandshakeType::Finished,
39                     epoch: cfg.initial_epoch + 1,
40                     is_client: true,
41                     optional: false,
42                 }],
43             )
44             .await
45         {
46             Ok((seq, msgs)) => (seq, msgs),
47             // No valid message received. Keep reading
48             Err(_) => return Err((None, None)),
49         };
50 
51         if let Some(message) = msgs.get(&HandshakeType::Finished) {
52             match message {
53                 HandshakeMessage::Finished(_) => {}
54                 _ => {
55                     return Err((
56                         Some(Alert {
57                             alert_level: AlertLevel::Fatal,
58                             alert_description: AlertDescription::InternalError,
59                         }),
60                         None,
61                     ))
62                 }
63             };
64         }
65 
66         // Other party retransmitted the last flight.
67         Ok(Box::new(Flight6 {}))
68     }
69 
generate( &self, state: &mut State, cache: &HandshakeCache, cfg: &HandshakeConfig, ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)>70     async fn generate(
71         &self,
72         state: &mut State,
73         cache: &HandshakeCache,
74         cfg: &HandshakeConfig,
75     ) -> Result<Vec<Packet>, (Option<Alert>, Option<Error>)> {
76         let mut pkts = vec![Packet {
77             record: RecordLayer::new(
78                 PROTOCOL_VERSION1_2,
79                 0,
80                 Content::ChangeCipherSpec(ChangeCipherSpec {}),
81             ),
82             should_encrypt: false,
83             reset_local_sequence_number: false,
84         }];
85 
86         if state.local_verify_data.is_empty() {
87             let plain_text = cache
88                 .pull_and_merge(&[
89                     HandshakeCachePullRule {
90                         typ: HandshakeType::ClientHello,
91                         epoch: cfg.initial_epoch,
92                         is_client: true,
93                         optional: false,
94                     },
95                     HandshakeCachePullRule {
96                         typ: HandshakeType::ServerHello,
97                         epoch: cfg.initial_epoch,
98                         is_client: false,
99                         optional: false,
100                     },
101                     HandshakeCachePullRule {
102                         typ: HandshakeType::Certificate,
103                         epoch: cfg.initial_epoch,
104                         is_client: false,
105                         optional: false,
106                     },
107                     HandshakeCachePullRule {
108                         typ: HandshakeType::ServerKeyExchange,
109                         epoch: cfg.initial_epoch,
110                         is_client: false,
111                         optional: false,
112                     },
113                     HandshakeCachePullRule {
114                         typ: HandshakeType::CertificateRequest,
115                         epoch: cfg.initial_epoch,
116                         is_client: false,
117                         optional: false,
118                     },
119                     HandshakeCachePullRule {
120                         typ: HandshakeType::ServerHelloDone,
121                         epoch: cfg.initial_epoch,
122                         is_client: false,
123                         optional: false,
124                     },
125                     HandshakeCachePullRule {
126                         typ: HandshakeType::Certificate,
127                         epoch: cfg.initial_epoch,
128                         is_client: true,
129                         optional: false,
130                     },
131                     HandshakeCachePullRule {
132                         typ: HandshakeType::ClientKeyExchange,
133                         epoch: cfg.initial_epoch,
134                         is_client: true,
135                         optional: false,
136                     },
137                     HandshakeCachePullRule {
138                         typ: HandshakeType::CertificateVerify,
139                         epoch: cfg.initial_epoch,
140                         is_client: true,
141                         optional: false,
142                     },
143                     HandshakeCachePullRule {
144                         typ: HandshakeType::Finished,
145                         epoch: cfg.initial_epoch + 1,
146                         is_client: true,
147                         optional: false,
148                     },
149                 ])
150                 .await;
151 
152             let cipher_suite = state.cipher_suite.lock().await;
153             if let Some(cipher_suite) = &*cipher_suite {
154                 state.local_verify_data = match prf_verify_data_server(
155                     &state.master_secret,
156                     &plain_text,
157                     cipher_suite.hash_func(),
158                 ) {
159                     Ok(data) => data,
160                     Err(err) => {
161                         return Err((
162                             Some(Alert {
163                                 alert_level: AlertLevel::Fatal,
164                                 alert_description: AlertDescription::InternalError,
165                             }),
166                             Some(err),
167                         ))
168                     }
169                 };
170             }
171         }
172 
173         pkts.push(Packet {
174             record: RecordLayer::new(
175                 PROTOCOL_VERSION1_2,
176                 1,
177                 Content::Handshake(Handshake::new(HandshakeMessage::Finished(
178                     HandshakeMessageFinished {
179                         verify_data: state.local_verify_data.clone(),
180                     },
181                 ))),
182             ),
183             should_encrypt: true,
184             reset_local_sequence_number: true,
185         });
186 
187         Ok(pkts)
188     }
189 }
190