xref: /webrtc/dtls/src/conn/mod.rs (revision 5d8fe953)
1 #[cfg(test)]
2 mod conn_test;
3 
4 use crate::alert::*;
5 use crate::application_data::*;
6 use crate::cipher_suite::*;
7 use crate::config::*;
8 use crate::content::*;
9 use crate::curve::named_curve::NamedCurve;
10 use crate::error::*;
11 use crate::extension::extension_use_srtp::*;
12 use crate::flight::flight0::*;
13 use crate::flight::flight1::*;
14 use crate::flight::flight5::*;
15 use crate::flight::flight6::*;
16 use crate::flight::*;
17 use crate::fragment_buffer::*;
18 use crate::handshake::handshake_cache::*;
19 use crate::handshake::handshake_header::HandshakeHeader;
20 use crate::handshake::*;
21 use crate::handshaker::*;
22 use crate::record_layer::record_layer_header::*;
23 use crate::record_layer::*;
24 use crate::signature_hash_algorithm::parse_signature_schemes;
25 use crate::state::*;
26 
27 use util::{replay_detector::*, Conn};
28 
29 use async_trait::async_trait;
30 use log::*;
31 use std::io::{BufReader, BufWriter};
32 use std::marker::{Send, Sync};
33 use std::net::SocketAddr;
34 use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
35 use std::sync::Arc;
36 use tokio::sync::{mpsc, Mutex};
37 use tokio::time::Duration;
38 
/// Retransmission interval used when `Config::flight_interval` is left at zero.
pub(crate) const INITIAL_TICKER_INTERVAL: Duration = Duration::from_millis(1_000);
/// Cookie length in bytes (presumably for HelloVerifyRequest cookies — not used in this file section; TODO confirm).
pub(crate) const COOKIE_LENGTH: usize = 20;
41 pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519;
/// Size in bytes of the buffer the background reader task receives datagrams into.
pub(crate) const INBOUND_BUFFER_SIZE: usize = 8 * 1024;
/// Replay-protection window used when `Config::replay_protection_window` is zero.
/// The default window is specified by RFC 6347 Section 4.1.2.6.
pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64;
45 
/// Labels that are presumably rejected for keying-material export (TODO confirm at the
/// exporter call site). These are the labels TLS 1.2 itself feeds to its PRF (RFC 5246).
pub static INVALID_KEYING_LABELS: &[&str] =
    &["client finished", "server finished", "master secret", "key expansion"];
52 
53 type PacketSendRequest = (Vec<Packet>, Option<mpsc::Sender<Result<()>>>);
54 
55 struct ConnReaderContext {
56     is_client: bool,
57     replay_protection_window: usize,
58     replay_detector: Vec<Box<dyn ReplayDetector + Send>>,
59     decrypted_tx: mpsc::Sender<Result<Vec<u8>>>,
60     encrypted_packets: Vec<Vec<u8>>,
61     fragment_buffer: FragmentBuffer,
62     cache: HandshakeCache,
63     cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
64     remote_epoch: Arc<AtomicU16>,
65     handshake_tx: mpsc::Sender<mpsc::Sender<()>>,
66     handshake_done_rx: mpsc::Receiver<()>,
67     packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
68 }
69 
70 // Conn represents a DTLS connection
71 pub struct DTLSConn {
72     conn: Arc<dyn Conn + Send + Sync>,
73     pub(crate) cache: HandshakeCache, // caching of handshake messages for verifyData generation
74     decrypted_rx: Mutex<mpsc::Receiver<Result<Vec<u8>>>>, // Decrypted Application Data or error, pull by calling `Read`
75     pub(crate) state: State,                              // Internal state
76 
77     handshake_completed_successfully: Arc<AtomicBool>,
78     connection_closed_by_user: bool,
79     // closeLock              sync.Mutex
80     closed: AtomicBool, //  *closer.Closer
81     //handshakeLoopsFinished sync.WaitGroup
82 
83     //readDeadline  :deadline.Deadline,
84     //writeDeadline :deadline.Deadline,
85 
86     //log logging.LeveledLogger
87     /*
88     reading               chan struct{}
89     handshakeRecv         chan chan struct{}
90     cancelHandshaker      func()
91     cancelHandshakeReader func()
92     */
93     pub(crate) current_flight: Box<dyn Flight + Send + Sync>,
94     pub(crate) flights: Option<Vec<Packet>>,
95     pub(crate) cfg: HandshakeConfig,
96     pub(crate) retransmit: bool,
97     pub(crate) handshake_rx: mpsc::Receiver<mpsc::Sender<()>>,
98 
99     pub(crate) packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
100     pub(crate) handle_queue_tx: mpsc::Sender<mpsc::Sender<()>>,
101     pub(crate) handshake_done_tx: Option<mpsc::Sender<()>>,
102 
103     reader_close_tx: Mutex<Option<mpsc::Sender<()>>>,
104 }
105 
106 type UtilResult<T> = std::result::Result<T, util::Error>;
107 
108 #[async_trait]
109 impl Conn for DTLSConn {
connect(&self, _addr: SocketAddr) -> UtilResult<()>110     async fn connect(&self, _addr: SocketAddr) -> UtilResult<()> {
111         Err(util::Error::Other("Not applicable".to_owned()))
112     }
recv(&self, buf: &mut [u8]) -> UtilResult<usize>113     async fn recv(&self, buf: &mut [u8]) -> UtilResult<usize> {
114         self.read(buf, None).await.map_err(util::Error::from_std)
115     }
recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)>116     async fn recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)> {
117         if let Some(raddr) = self.conn.remote_addr() {
118             let n = self.read(buf, None).await.map_err(util::Error::from_std)?;
119             Ok((n, raddr))
120         } else {
121             Err(util::Error::Other(
122                 "No remote address is provided by underlying Conn".to_owned(),
123             ))
124         }
125     }
send(&self, buf: &[u8]) -> UtilResult<usize>126     async fn send(&self, buf: &[u8]) -> UtilResult<usize> {
127         self.write(buf, None).await.map_err(util::Error::from_std)
128     }
send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize>129     async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize> {
130         Err(util::Error::Other("Not applicable".to_owned()))
131     }
local_addr(&self) -> UtilResult<SocketAddr>132     fn local_addr(&self) -> UtilResult<SocketAddr> {
133         self.conn.local_addr()
134     }
remote_addr(&self) -> Option<SocketAddr>135     fn remote_addr(&self) -> Option<SocketAddr> {
136         self.conn.remote_addr()
137     }
close(&self) -> UtilResult<()>138     async fn close(&self) -> UtilResult<()> {
139         self.close().await.map_err(util::Error::from_std)
140     }
141 }
142 
143 impl DTLSConn {
new( conn: Arc<dyn Conn + Send + Sync>, mut config: Config, is_client: bool, initial_state: Option<State>, ) -> Result<Self>144     pub async fn new(
145         conn: Arc<dyn Conn + Send + Sync>,
146         mut config: Config,
147         is_client: bool,
148         initial_state: Option<State>,
149     ) -> Result<Self> {
150         validate_config(is_client, &config)?;
151 
152         let local_cipher_suites: Vec<CipherSuiteId> = parse_cipher_suites(
153             &config.cipher_suites,
154             config.psk.is_none(),
155             config.psk.is_some(),
156         )?
157         .iter()
158         .map(|cs| cs.id())
159         .collect();
160 
161         let sigs: Vec<u16> = config.signature_schemes.iter().map(|x| *x as u16).collect();
162         let local_signature_schemes = parse_signature_schemes(&sigs, config.insecure_hashes)?;
163 
164         let retransmit_interval = if config.flight_interval != Duration::from_secs(0) {
165             config.flight_interval
166         } else {
167             INITIAL_TICKER_INTERVAL
168         };
169 
170         /*
171            loggerFactory := config.LoggerFactory
172            if loggerFactory == nil {
173                loggerFactory = logging.NewDefaultLoggerFactory()
174            }
175 
176            logger := loggerFactory.NewLogger("dtls")
177         */
178         let maximum_transmission_unit = if config.mtu == 0 {
179             DEFAULT_MTU
180         } else {
181             config.mtu
182         };
183 
184         let replay_protection_window = if config.replay_protection_window == 0 {
185             DEFAULT_REPLAY_PROTECTION_WINDOW
186         } else {
187             config.replay_protection_window
188         };
189 
190         let mut server_name = config.server_name.clone();
191 
192         // Use host from conn address when server_name is not provided
193         if is_client && server_name.is_empty() {
194             if let Some(remote_addr) = conn.remote_addr() {
195                 server_name = remote_addr.ip().to_string();
196             } else {
197                 log::warn!("conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now");
198                 server_name = "localhost".to_owned();
199             }
200         }
201 
202         let cfg = HandshakeConfig {
203             local_psk_callback: config.psk.take(),
204             local_psk_identity_hint: config.psk_identity_hint.take(),
205             local_cipher_suites,
206             local_signature_schemes,
207             extended_master_secret: config.extended_master_secret,
208             local_srtp_protection_profiles: config.srtp_protection_profiles.clone(),
209             server_name,
210             client_auth: config.client_auth,
211             local_certificates: config.certificates.clone(),
212             insecure_skip_verify: config.insecure_skip_verify,
213             insecure_verification: config.insecure_verification,
214             verify_peer_certificate: config.verify_peer_certificate.take(),
215             roots_cas: config.roots_cas,
216             client_cert_verifier: if config.client_auth as u8
217                 >= ClientAuthType::VerifyClientCertIfGiven as u8
218             {
219                 Some(rustls::AllowAnyAuthenticatedClient::new(config.client_cas))
220             } else {
221                 None
222             },
223             retransmit_interval,
224             //log: logger,
225             initial_epoch: 0,
226             ..Default::default()
227         };
228 
229         let (state, flight, initial_fsm_state) = if let Some(state) = initial_state {
230             let flight = if is_client {
231                 Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>
232             } else {
233                 Box::new(Flight6 {}) as Box<dyn Flight + Send + Sync>
234             };
235 
236             (state, flight, HandshakeState::Finished)
237         } else {
238             let flight = if is_client {
239                 Box::new(Flight1 {}) as Box<dyn Flight + Send + Sync>
240             } else {
241                 Box::new(Flight0 {}) as Box<dyn Flight + Send + Sync>
242             };
243 
244             (
245                 State {
246                     is_client,
247                     ..Default::default()
248                 },
249                 flight,
250                 HandshakeState::Preparing,
251             )
252         };
253 
254         let (decrypted_tx, decrypted_rx) = mpsc::channel(1);
255         let (handshake_tx, handshake_rx) = mpsc::channel(1);
256         let (handshake_done_tx, handshake_done_rx) = mpsc::channel(1);
257         let (packet_tx, mut packet_rx) = mpsc::channel(1);
258         let (handle_queue_tx, mut handle_queue_rx) = mpsc::channel(1);
259         let (reader_close_tx, mut reader_close_rx) = mpsc::channel(1);
260 
261         let packet_tx = Arc::new(packet_tx);
262         let packet_tx2 = Arc::clone(&packet_tx);
263         let next_conn_rx = Arc::clone(&conn);
264         let next_conn_tx = Arc::clone(&conn);
265         let cache = HandshakeCache::new();
266         let mut cache1 = cache.clone();
267         let cache2 = cache.clone();
268         let handshake_completed_successfully = Arc::new(AtomicBool::new(false));
269         let handshake_completed_successfully2 = Arc::clone(&handshake_completed_successfully);
270 
271         let mut c = DTLSConn {
272             conn: Arc::clone(&conn),
273             cache,
274             decrypted_rx: Mutex::new(decrypted_rx),
275             state,
276             handshake_completed_successfully,
277             connection_closed_by_user: false,
278             closed: AtomicBool::new(false),
279 
280             current_flight: flight,
281             flights: None,
282             cfg,
283             retransmit: false,
284             handshake_rx,
285             packet_tx,
286             handle_queue_tx,
287             handshake_done_tx: Some(handshake_done_tx),
288             reader_close_tx: Mutex::new(Some(reader_close_tx)),
289         };
290 
291         let cipher_suite1 = Arc::clone(&c.state.cipher_suite);
292         let sequence_number = Arc::clone(&c.state.local_sequence_number);
293 
294         tokio::spawn(async move {
295             loop {
296                 let rx = packet_rx.recv().await;
297                 if let Some(r) = rx {
298                     let (pkt, result_tx) = r;
299 
300                     let result = DTLSConn::handle_outgoing_packets(
301                         &next_conn_tx,
302                         pkt,
303                         &mut cache1,
304                         is_client,
305                         &sequence_number,
306                         &cipher_suite1,
307                         maximum_transmission_unit,
308                     )
309                     .await;
310 
311                     if let Some(tx) = result_tx {
312                         let _ = tx.send(result).await;
313                     }
314                 } else {
315                     trace!("{}: handle_outgoing_packets exit", srv_cli_str(is_client));
316                     break;
317                 }
318             }
319         });
320 
321         let local_epoch = Arc::clone(&c.state.local_epoch);
322         let remote_epoch = Arc::clone(&c.state.remote_epoch);
323         let cipher_suite2 = Arc::clone(&c.state.cipher_suite);
324 
325         tokio::spawn(async move {
326             let mut buf = vec![0u8; INBOUND_BUFFER_SIZE];
327             let mut ctx = ConnReaderContext {
328                 is_client,
329                 replay_protection_window,
330                 replay_detector: vec![],
331                 decrypted_tx,
332                 encrypted_packets: vec![],
333                 fragment_buffer: FragmentBuffer::new(),
334                 cache: cache2,
335                 cipher_suite: cipher_suite2,
336                 remote_epoch,
337                 handshake_tx,
338                 handshake_done_rx,
339                 packet_tx: packet_tx2,
340             };
341 
342             //trace!("before enter read_and_buffer: {}] ", srv_cli_str(is_client));
343             loop {
344                 tokio::select! {
345                     _ = reader_close_rx.recv() => {
346                         trace!(
347                                 "{}: read_and_buffer exit",
348                                 srv_cli_str(ctx.is_client),
349                             );
350                         break;
351                     }
352                     result = DTLSConn::read_and_buffer(
353                                             &mut ctx,
354                                             &next_conn_rx,
355                                             &mut handle_queue_rx,
356                                             &mut buf,
357                                             &local_epoch,
358                                             &handshake_completed_successfully2,
359                                         ) => {
360                         if let Err(err) = result {
361                             trace!(
362                                 "{}: read_and_buffer return err: {}",
363                                 srv_cli_str(is_client),
364                                 err
365                             );
366                             if Error::ErrAlertFatalOrClose == err {
367                                 trace!(
368                                     "{}: read_and_buffer exit with {}",
369                                     srv_cli_str(ctx.is_client),
370                                     err
371                                 );
372 
373                                 break;
374                             }
375                         }
376                     }
377                 }
378             }
379         });
380 
381         // Do handshake
382         c.handshake(initial_fsm_state).await?;
383 
384         trace!("Handshake Completed");
385 
386         Ok(c)
387     }
388 
389     // Read reads data from the connection.
read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize>390     pub async fn read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize> {
391         if !self.is_handshake_completed_successfully() {
392             return Err(Error::ErrHandshakeInProgress);
393         }
394 
395         let rx = {
396             let mut decrypted_rx = self.decrypted_rx.lock().await;
397             if let Some(d) = duration {
398                 let timer = tokio::time::sleep(d);
399                 tokio::pin!(timer);
400 
401                 tokio::select! {
402                     r = decrypted_rx.recv() => r,
403                     _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
404                 }
405             } else {
406                 decrypted_rx.recv().await
407             }
408         };
409 
410         if let Some(out) = rx {
411             match out {
412                 Ok(val) => {
413                     let n = val.len();
414                     if p.len() < n {
415                         return Err(Error::ErrBufferTooSmall);
416                     }
417                     p[..n].copy_from_slice(&val);
418                     Ok(n)
419                 }
420                 Err(err) => Err(err),
421             }
422         } else {
423             Err(Error::ErrAlertFatalOrClose)
424         }
425     }
426 
427     // Write writes len(p) bytes from p to the DTLS connection
write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize>428     pub async fn write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize> {
429         if self.is_connection_closed() {
430             return Err(Error::ErrConnClosed);
431         }
432 
433         if !self.is_handshake_completed_successfully() {
434             return Err(Error::ErrHandshakeInProgress);
435         }
436 
437         let pkts = vec![Packet {
438             record: RecordLayer::new(
439                 PROTOCOL_VERSION1_2,
440                 self.get_local_epoch(),
441                 Content::ApplicationData(ApplicationData { data: p.to_vec() }),
442             ),
443             should_encrypt: true,
444             reset_local_sequence_number: false,
445         }];
446 
447         if let Some(d) = duration {
448             let timer = tokio::time::sleep(d);
449             tokio::pin!(timer);
450 
451             tokio::select! {
452                 result = self.write_packets(pkts) => {
453                     result?;
454                 }
455                 _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
456             }
457         } else {
458             self.write_packets(pkts).await?;
459         }
460 
461         Ok(p.len())
462     }
463 
464     // Close closes the connection.
close(&self) -> Result<()>465     pub async fn close(&self) -> Result<()> {
466         if !self.closed.load(Ordering::SeqCst) {
467             self.closed.store(true, Ordering::SeqCst);
468 
469             // Discard error from notify() to return non-error on the first user call of Close()
470             // even if the underlying connection is already closed.
471             self.notify(AlertLevel::Warning, AlertDescription::CloseNotify)
472                 .await?;
473 
474             {
475                 let mut reader_close_tx = self.reader_close_tx.lock().await;
476                 reader_close_tx.take();
477             }
478             self.conn.close().await?;
479         }
480 
481         Ok(())
482     }
483 
484     /// connection_state returns basic DTLS details about the connection.
485     /// Note that this replaced the `Export` function of v1.
connection_state(&self) -> State486     pub async fn connection_state(&self) -> State {
487         self.state.clone().await
488     }
489 
490     /// selected_srtpprotection_profile returns the selected SRTPProtectionProfile
selected_srtpprotection_profile(&self) -> SrtpProtectionProfile491     pub fn selected_srtpprotection_profile(&self) -> SrtpProtectionProfile {
492         self.state.srtp_protection_profile
493     }
494 
notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()>495     pub(crate) async fn notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()> {
496         self.write_packets(vec![Packet {
497             record: RecordLayer::new(
498                 PROTOCOL_VERSION1_2,
499                 self.get_local_epoch(),
500                 Content::Alert(Alert {
501                     alert_level: level,
502                     alert_description: desc,
503                 }),
504             ),
505             should_encrypt: self.is_handshake_completed_successfully(),
506             reset_local_sequence_number: false,
507         }])
508         .await
509     }
510 
write_packets(&self, pkts: Vec<Packet>) -> Result<()>511     pub(crate) async fn write_packets(&self, pkts: Vec<Packet>) -> Result<()> {
512         let (tx, mut rx) = mpsc::channel(1);
513 
514         self.packet_tx.send((pkts, Some(tx))).await?;
515 
516         if let Some(result) = rx.recv().await {
517             result
518         } else {
519             Ok(())
520         }
521     }
522 
handle_outgoing_packets( next_conn: &Arc<dyn util::Conn + Send + Sync>, mut pkts: Vec<Packet>, cache: &mut HandshakeCache, is_client: bool, local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, maximum_transmission_unit: usize, ) -> Result<()>523     async fn handle_outgoing_packets(
524         next_conn: &Arc<dyn util::Conn + Send + Sync>,
525         mut pkts: Vec<Packet>,
526         cache: &mut HandshakeCache,
527         is_client: bool,
528         local_sequence_number: &Arc<Mutex<Vec<u64>>>,
529         cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
530         maximum_transmission_unit: usize,
531     ) -> Result<()> {
532         let mut raw_packets = vec![];
533         for p in &mut pkts {
534             if let Content::Handshake(h) = &p.record.content {
535                 let mut handshake_raw = vec![];
536                 {
537                     let mut writer = BufWriter::<&mut Vec<u8>>::new(handshake_raw.as_mut());
538                     p.record.marshal(&mut writer)?;
539                 }
540                 trace!(
541                     "Send [handshake:{}] -> {} (epoch: {}, seq: {})",
542                     srv_cli_str(is_client),
543                     h.handshake_header.handshake_type.to_string(),
544                     p.record.record_layer_header.epoch,
545                     h.handshake_header.message_sequence
546                 );
547                 cache
548                     .push(
549                         handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(),
550                         p.record.record_layer_header.epoch,
551                         h.handshake_header.message_sequence,
552                         h.handshake_header.handshake_type,
553                         is_client,
554                     )
555                     .await;
556 
557                 let raw_handshake_packets = DTLSConn::process_handshake_packet(
558                     local_sequence_number,
559                     cipher_suite,
560                     maximum_transmission_unit,
561                     p,
562                     h,
563                 )
564                 .await?;
565                 raw_packets.extend_from_slice(&raw_handshake_packets);
566             } else {
567                 /*if let Content::Alert(a) = &p.record.content {
568                     if a.alert_description == AlertDescription::CloseNotify {
569                         closed = true;
570                     }
571                 }*/
572 
573                 let raw_packet =
574                     DTLSConn::process_packet(local_sequence_number, cipher_suite, p).await?;
575                 raw_packets.push(raw_packet);
576             }
577         }
578 
579         if !raw_packets.is_empty() {
580             let compacted_raw_packets =
581                 compact_raw_packets(&raw_packets, maximum_transmission_unit);
582 
583             for compacted_raw_packets in &compacted_raw_packets {
584                 next_conn.send(compacted_raw_packets).await?;
585             }
586         }
587 
588         Ok(())
589     }
590 
process_packet( local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, p: &mut Packet, ) -> Result<Vec<u8>>591     async fn process_packet(
592         local_sequence_number: &Arc<Mutex<Vec<u64>>>,
593         cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
594         p: &mut Packet,
595     ) -> Result<Vec<u8>> {
596         let epoch = p.record.record_layer_header.epoch as usize;
597         let seq = {
598             let mut lsn = local_sequence_number.lock().await;
599             while lsn.len() <= epoch {
600                 lsn.push(0);
601             }
602 
603             lsn[epoch] += 1;
604             lsn[epoch] - 1
605         };
606         //trace!("{}: seq = {}", srv_cli_str(is_client), seq);
607 
608         if seq > MAX_SEQUENCE_NUMBER {
609             // RFC 6347 Section 4.1.0
610             // The implementation must either abandon an association or rehandshake
611             // prior to allowing the sequence number to wrap.
612             return Err(Error::ErrSequenceNumberOverflow);
613         }
614         p.record.record_layer_header.sequence_number = seq;
615 
616         let mut raw_packet = vec![];
617         {
618             let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_packet.as_mut());
619             p.record.marshal(&mut writer)?;
620         }
621 
622         if p.should_encrypt {
623             let cipher_suite = cipher_suite.lock().await;
624             if let Some(cipher_suite) = &*cipher_suite {
625                 raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?;
626             }
627         }
628 
629         Ok(raw_packet)
630     }
631 
process_handshake_packet( local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, maximum_transmission_unit: usize, p: &Packet, h: &Handshake, ) -> Result<Vec<Vec<u8>>>632     async fn process_handshake_packet(
633         local_sequence_number: &Arc<Mutex<Vec<u64>>>,
634         cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
635         maximum_transmission_unit: usize,
636         p: &Packet,
637         h: &Handshake,
638     ) -> Result<Vec<Vec<u8>>> {
639         let mut raw_packets = vec![];
640 
641         let handshake_fragments = DTLSConn::fragment_handshake(maximum_transmission_unit, h)?;
642 
643         let epoch = p.record.record_layer_header.epoch as usize;
644 
645         let mut lsn = local_sequence_number.lock().await;
646         while lsn.len() <= epoch {
647             lsn.push(0);
648         }
649 
650         for handshake_fragment in &handshake_fragments {
651             let seq = {
652                 lsn[epoch] += 1;
653                 lsn[epoch] - 1
654             };
655             //trace!("seq = {}", seq);
656             if seq > MAX_SEQUENCE_NUMBER {
657                 return Err(Error::ErrSequenceNumberOverflow);
658             }
659 
660             let record_layer_header = RecordLayerHeader {
661                 protocol_version: p.record.record_layer_header.protocol_version,
662                 content_type: p.record.record_layer_header.content_type,
663                 content_len: handshake_fragment.len() as u16,
664                 epoch: p.record.record_layer_header.epoch,
665                 sequence_number: seq,
666             };
667 
668             let mut record_layer_header_bytes = vec![];
669             {
670                 let mut writer = BufWriter::<&mut Vec<u8>>::new(record_layer_header_bytes.as_mut());
671                 record_layer_header.marshal(&mut writer)?;
672             }
673 
674             //p.record.record_layer_header = record_layer_header;
675 
676             let mut raw_packet = vec![];
677             raw_packet.extend_from_slice(&record_layer_header_bytes);
678             raw_packet.extend_from_slice(handshake_fragment);
679             if p.should_encrypt {
680                 let cipher_suite = cipher_suite.lock().await;
681                 if let Some(cipher_suite) = &*cipher_suite {
682                     raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?;
683                 }
684             }
685 
686             raw_packets.push(raw_packet);
687         }
688 
689         Ok(raw_packets)
690     }
691 
fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>>692     fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>> {
693         let mut content = vec![];
694         {
695             let mut writer = BufWriter::<&mut Vec<u8>>::new(content.as_mut());
696             h.handshake_message.marshal(&mut writer)?;
697         }
698 
699         let mut fragmented_handshakes = vec![];
700 
701         let mut content_fragments = split_bytes(&content, maximum_transmission_unit);
702         if content_fragments.is_empty() {
703             content_fragments = vec![vec![]];
704         }
705 
706         let mut offset = 0;
707         for content_fragment in &content_fragments {
708             let content_fragment_len = content_fragment.len();
709 
710             let handshake_header_fragment = HandshakeHeader {
711                 handshake_type: h.handshake_header.handshake_type,
712                 length: h.handshake_header.length,
713                 message_sequence: h.handshake_header.message_sequence,
714                 fragment_offset: offset as u32,
715                 fragment_length: content_fragment_len as u32,
716             };
717 
718             offset += content_fragment_len;
719 
720             let mut handshake_header_fragment_raw = vec![];
721             {
722                 let mut writer =
723                     BufWriter::<&mut Vec<u8>>::new(handshake_header_fragment_raw.as_mut());
724                 handshake_header_fragment.marshal(&mut writer)?;
725             }
726 
727             let mut fragmented_handshake = vec![];
728             fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw);
729             fragmented_handshake.extend_from_slice(content_fragment);
730 
731             fragmented_handshakes.push(fragmented_handshake);
732         }
733 
734         Ok(fragmented_handshakes)
735     }
736 
set_handshake_completed_successfully(&mut self)737     pub(crate) fn set_handshake_completed_successfully(&mut self) {
738         self.handshake_completed_successfully
739             .store(true, Ordering::SeqCst);
740     }
741 
is_handshake_completed_successfully(&self) -> bool742     pub(crate) fn is_handshake_completed_successfully(&self) -> bool {
743         self.handshake_completed_successfully.load(Ordering::SeqCst)
744     }
745 
    /// Reads one datagram from `next_conn`, splits it into records and runs each
    /// record through `handle_incoming_packet` with queuing enabled.
    ///
    /// Whenever handling a record yields an alert, an alert record is sent to the
    /// peer through `ctx.packet_tx`, tagged with the current `local_epoch` and
    /// encrypted only once `handshake_completed_successfully` is set.
    ///
    /// If at least one record carried handshake data, a fresh `done` sender is
    /// offered on `ctx.handshake_tx`, and this function then waits until that
    /// channel resolves. While waiting, every message received on
    /// `handle_queue_rx` re-processes the records deferred in
    /// `ctx.encrypted_packets`. If `ctx.handshake_done_rx` resolves before the
    /// offer is accepted, the offer is abandoned.
    ///
    /// # Errors
    /// * Errors from `next_conn.recv` and `unpack_datagram`.
    /// * `Error::ErrAlertFatalOrClose` when a produced alert is fatal or close_notify.
    /// * Otherwise, the first error reported by `handle_incoming_packet` or, if
    ///   none, the failure to queue the alert record.
    /// * Errors propagated from `handle_queued_packets`.
    async fn read_and_buffer(
        ctx: &mut ConnReaderContext,
        next_conn: &Arc<dyn util::Conn + Send + Sync>,
        handle_queue_rx: &mut mpsc::Receiver<mpsc::Sender<()>>,
        buf: &mut [u8],
        local_epoch: &Arc<AtomicU16>,
        handshake_completed_successfully: &Arc<AtomicBool>,
    ) -> Result<()> {
        let n = next_conn.recv(buf).await?;
        // A single datagram may carry several DTLS records.
        let pkts = unpack_datagram(&buf[..n])?;
        let mut has_handshake = false;
        for pkt in pkts {
            // `true`: records that cannot be processed yet are queued for later.
            let (hs, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, pkt, true).await;
            if let Some(alert) = alert {
                // Send the alert to the peer; encrypt it only after the handshake completed.
                let alert_err = ctx
                    .packet_tx
                    .send((
                        vec![Packet {
                            record: RecordLayer::new(
                                PROTOCOL_VERSION1_2,
                                local_epoch.load(Ordering::SeqCst),
                                Content::Alert(Alert {
                                    alert_level: alert.alert_level,
                                    alert_description: alert.alert_description,
                                }),
                            ),
                            should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
                            reset_local_sequence_number: false,
                        }],
                        None,
                    ))
                    .await;

                // The handling error (if any) takes precedence over a send failure.
                if let Err(alert_err) = alert_err {
                    if err.is_none() {
                        err = Some(Error::Other(alert_err.to_string()));
                    }
                }

                if alert.alert_level == AlertLevel::Fatal
                    || alert.alert_description == AlertDescription::CloseNotify
                {
                    return Err(Error::ErrAlertFatalOrClose);
                }
            }

            if let Some(err) = err {
                return Err(err);
            }

            if hs {
                has_handshake = true
            }
        }

        if has_handshake {
            // `done_rx.recv()` resolves when the receiver of `done_tx` sends on it
            // or drops it.
            let (done_tx, mut done_rx) = mpsc::channel(1);

            tokio::select! {
                _ = ctx.handshake_tx.send(done_tx) => {
                    let mut wait_done_rx = true;
                    // NOTE(review): if `handle_queue_rx` is closed, its `recv()` resolves to
                    // `None` immediately on every iteration, so this loop would spin until
                    // `done_rx` resolves — confirm the sender always outlives this call.
                    while wait_done_rx{
                        tokio::select!{
                            _ = done_rx.recv() => {
                                // If the other party may retransmit the flight,
                                // we should respond even if it not a new message.
                                wait_done_rx = false;
                            }
                            done = handle_queue_rx.recv() => {
                                //trace!("recv handle_queue: {} ", srv_cli_str(ctx.is_client));

                                // Re-run every record deferred so far (e.g. because the
                                // cipher suite was not ready when it arrived).
                                let pkts = ctx.encrypted_packets.drain(..).collect();
                                DTLSConn::handle_queued_packets(ctx, local_epoch, handshake_completed_successfully, pkts).await?;

                                // Dropping the received sender closes that channel;
                                // presumably this tells the requester the queue was processed.
                                drop(done);
                            }
                        }
                    }
                }
                _ = ctx.handshake_done_rx.recv() => {}
            }
        }

        Ok(())
    }
831 
handle_queued_packets( ctx: &mut ConnReaderContext, local_epoch: &Arc<AtomicU16>, handshake_completed_successfully: &Arc<AtomicBool>, pkts: Vec<Vec<u8>>, ) -> Result<()>832     async fn handle_queued_packets(
833         ctx: &mut ConnReaderContext,
834         local_epoch: &Arc<AtomicU16>,
835         handshake_completed_successfully: &Arc<AtomicBool>,
836         pkts: Vec<Vec<u8>>,
837     ) -> Result<()> {
838         for p in pkts {
839             let (_, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, p, false).await; // don't re-enqueue
840             if let Some(alert) = alert {
841                 let alert_err = ctx
842                     .packet_tx
843                     .send((
844                         vec![Packet {
845                             record: RecordLayer::new(
846                                 PROTOCOL_VERSION1_2,
847                                 local_epoch.load(Ordering::SeqCst),
848                                 Content::Alert(Alert {
849                                     alert_level: alert.alert_level,
850                                     alert_description: alert.alert_description,
851                                 }),
852                             ),
853                             should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
854                             reset_local_sequence_number: false,
855                         }],
856                         None,
857                     ))
858                     .await;
859 
860                 if let Err(alert_err) = alert_err {
861                     if err.is_none() {
862                         err = Some(Error::Other(alert_err.to_string()));
863                     }
864                 }
865                 if alert.alert_level == AlertLevel::Fatal
866                     || alert.alert_description == AlertDescription::CloseNotify
867                 {
868                     return Err(Error::ErrAlertFatalOrClose);
869                 }
870             }
871 
872             if let Some(err) = err {
873                 return Err(err);
874             }
875         }
876 
877         Ok(())
878     }
879 
handle_incoming_packet( ctx: &mut ConnReaderContext, mut pkt: Vec<u8>, enqueue: bool, ) -> (bool, Option<Alert>, Option<Error>)880     async fn handle_incoming_packet(
881         ctx: &mut ConnReaderContext,
882         mut pkt: Vec<u8>,
883         enqueue: bool,
884     ) -> (bool, Option<Alert>, Option<Error>) {
885         let mut reader = BufReader::new(pkt.as_slice());
886         let h = match RecordLayerHeader::unmarshal(&mut reader) {
887             Ok(h) => h,
888             Err(err) => {
889                 // Decode error must be silently discarded
890                 // [RFC6347 Section-4.1.2.7]
891                 debug!(
892                     "{}: discarded broken packet: {}",
893                     srv_cli_str(ctx.is_client),
894                     err
895                 );
896                 return (false, None, None);
897             }
898         };
899 
900         // Validate epoch
901         let epoch = ctx.remote_epoch.load(Ordering::SeqCst);
902         if h.epoch > epoch {
903             if h.epoch > epoch + 1 {
904                 debug!(
905                     "{}: discarded future packet (epoch: {}, seq: {})",
906                     srv_cli_str(ctx.is_client),
907                     h.epoch,
908                     h.sequence_number,
909                 );
910                 return (false, None, None);
911             }
912             if enqueue {
913                 debug!(
914                     "{}: received packet of next epoch, queuing packet",
915                     srv_cli_str(ctx.is_client)
916                 );
917                 ctx.encrypted_packets.push(pkt);
918             }
919             return (false, None, None);
920         }
921 
922         // Anti-replay protection
923         while ctx.replay_detector.len() <= h.epoch as usize {
924             ctx.replay_detector
925                 .push(Box::new(SlidingWindowDetector::new(
926                     ctx.replay_protection_window,
927                     MAX_SEQUENCE_NUMBER,
928                 )));
929         }
930 
931         let ok = ctx.replay_detector[h.epoch as usize].check(h.sequence_number);
932         if !ok {
933             debug!(
934                 "{}: discarded duplicated packet (epoch: {}, seq: {})",
935                 srv_cli_str(ctx.is_client),
936                 h.epoch,
937                 h.sequence_number,
938             );
939             return (false, None, None);
940         }
941 
942         // Decrypt
943         if h.epoch != 0 {
944             let invalid_cipher_suite = {
945                 let cipher_suite = ctx.cipher_suite.lock().await;
946                 if cipher_suite.is_none() {
947                     true
948                 } else if let Some(cipher_suite) = &*cipher_suite {
949                     !cipher_suite.is_initialized()
950                 } else {
951                     false
952                 }
953             };
954             if invalid_cipher_suite {
955                 if enqueue {
956                     debug!(
957                         "{}: handshake not finished, queuing packet",
958                         srv_cli_str(ctx.is_client)
959                     );
960                     ctx.encrypted_packets.push(pkt);
961                 }
962                 return (false, None, None);
963             }
964 
965             let cipher_suite = ctx.cipher_suite.lock().await;
966             if let Some(cipher_suite) = &*cipher_suite {
967                 pkt = match cipher_suite.decrypt(&pkt) {
968                     Ok(pkt) => pkt,
969                     Err(err) => {
970                         debug!("{}: decrypt failed: {}", srv_cli_str(ctx.is_client), err);
971                         return (false, None, None);
972                     }
973                 };
974             }
975         }
976 
977         let is_handshake = match ctx.fragment_buffer.push(&pkt) {
978             Ok(is_handshake) => is_handshake,
979             Err(err) => {
980                 // Decode error must be silently discarded
981                 // [RFC6347 Section-4.1.2.7]
982                 debug!("{}: defragment failed: {}", srv_cli_str(ctx.is_client), err);
983                 return (false, None, None);
984             }
985         };
986         if is_handshake {
987             ctx.replay_detector[h.epoch as usize].accept();
988             while let Ok((out, epoch)) = ctx.fragment_buffer.pop() {
989                 //log::debug!("Extension Debug: out.len()={}", out.len());
990                 let mut reader = BufReader::new(out.as_slice());
991                 let raw_handshake = match Handshake::unmarshal(&mut reader) {
992                     Ok(rh) => {
993                         trace!(
994                             "Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
995                             srv_cli_str(ctx.is_client),
996                             rh.handshake_header.handshake_type.to_string(),
997                             h.epoch,
998                             rh.handshake_header.message_sequence
999                         );
1000                         rh
1001                     }
1002                     Err(err) => {
1003                         debug!(
1004                             "{}: handshake parse failed: {}",
1005                             srv_cli_str(ctx.is_client),
1006                             err
1007                         );
1008                         continue;
1009                     }
1010                 };
1011 
1012                 ctx.cache
1013                     .push(
1014                         out,
1015                         epoch,
1016                         raw_handshake.handshake_header.message_sequence,
1017                         raw_handshake.handshake_header.handshake_type,
1018                         !ctx.is_client,
1019                     )
1020                     .await;
1021             }
1022 
1023             return (true, None, None);
1024         }
1025 
1026         let mut reader = BufReader::new(pkt.as_slice());
1027         let r = match RecordLayer::unmarshal(&mut reader) {
1028             Ok(r) => r,
1029             Err(err) => {
1030                 return (
1031                     false,
1032                     Some(Alert {
1033                         alert_level: AlertLevel::Fatal,
1034                         alert_description: AlertDescription::DecodeError,
1035                     }),
1036                     Some(err),
1037                 );
1038             }
1039         };
1040 
1041         match r.content {
1042             Content::Alert(mut a) => {
1043                 trace!("{}: <- {}", srv_cli_str(ctx.is_client), a.to_string());
1044                 if a.alert_description == AlertDescription::CloseNotify {
1045                     // Respond with a close_notify [RFC5246 Section 7.2.1]
1046                     a = Alert {
1047                         alert_level: AlertLevel::Warning,
1048                         alert_description: AlertDescription::CloseNotify,
1049                     };
1050                 }
1051                 ctx.replay_detector[h.epoch as usize].accept();
1052                 return (
1053                     false,
1054                     Some(a),
1055                     Some(Error::Other(format!("Error of Alert {a}"))),
1056                 );
1057             }
1058             Content::ChangeCipherSpec(_) => {
1059                 let invalid_cipher_suite = {
1060                     let cipher_suite = ctx.cipher_suite.lock().await;
1061                     if cipher_suite.is_none() {
1062                         true
1063                     } else if let Some(cipher_suite) = &*cipher_suite {
1064                         !cipher_suite.is_initialized()
1065                     } else {
1066                         false
1067                     }
1068                 };
1069 
1070                 if invalid_cipher_suite {
1071                     if enqueue {
1072                         debug!(
1073                             "{}: CipherSuite not initialized, queuing packet",
1074                             srv_cli_str(ctx.is_client)
1075                         );
1076                         ctx.encrypted_packets.push(pkt);
1077                     }
1078                     return (false, None, None);
1079                 }
1080 
1081                 let new_remote_epoch = h.epoch + 1;
1082                 trace!(
1083                     "{}: <- ChangeCipherSpec (epoch: {})",
1084                     srv_cli_str(ctx.is_client),
1085                     new_remote_epoch
1086                 );
1087 
1088                 if epoch + 1 == new_remote_epoch {
1089                     ctx.remote_epoch.store(new_remote_epoch, Ordering::SeqCst);
1090                     ctx.replay_detector[h.epoch as usize].accept();
1091                 }
1092             }
1093             Content::ApplicationData(a) => {
1094                 if h.epoch == 0 {
1095                     return (
1096                         false,
1097                         Some(Alert {
1098                             alert_level: AlertLevel::Fatal,
1099                             alert_description: AlertDescription::UnexpectedMessage,
1100                         }),
1101                         Some(Error::ErrApplicationDataEpochZero),
1102                     );
1103                 }
1104 
1105                 ctx.replay_detector[h.epoch as usize].accept();
1106 
1107                 let _ = ctx.decrypted_tx.send(Ok(a.data)).await;
1108                 //TODO
1109                 /*select {
1110                     case self.decrypted < - content.data:
1111                     case < -c.closed.Done():
1112                 }*/
1113             }
1114             _ => {
1115                 return (
1116                     false,
1117                     Some(Alert {
1118                         alert_level: AlertLevel::Fatal,
1119                         alert_description: AlertDescription::UnexpectedMessage,
1120                     }),
1121                     Some(Error::ErrUnhandledContextType),
1122                 );
1123             }
1124         };
1125 
1126         (false, None, None)
1127     }
1128 
is_connection_closed(&self) -> bool1129     fn is_connection_closed(&self) -> bool {
1130         self.closed.load(Ordering::SeqCst)
1131     }
1132 
set_local_epoch(&mut self, epoch: u16)1133     pub(crate) fn set_local_epoch(&mut self, epoch: u16) {
1134         self.state.local_epoch.store(epoch, Ordering::SeqCst);
1135     }
1136 
get_local_epoch(&self) -> u161137     pub(crate) fn get_local_epoch(&self) -> u16 {
1138         self.state.local_epoch.load(Ordering::SeqCst)
1139     }
1140 }
1141 
/// Greedily concatenates consecutive raw records into datagrams.
///
/// A new datagram is started whenever appending the next record to a non-empty
/// datagram would make it reach or exceed `maximum_transmission_unit` bytes.
/// A single record larger than the limit still gets its own datagram.
///
/// The final (possibly empty) datagram is always included. An empty
/// `raw_packets` therefore yields one empty datagram.
fn compact_raw_packets(raw_packets: &[Vec<u8>], maximum_transmission_unit: usize) -> Vec<Vec<u8>> {
    let mut batches: Vec<Vec<u8>> = Vec::new();
    let mut pending: Vec<u8> = Vec::new();

    for packet in raw_packets {
        let must_flush = !pending.is_empty()
            && pending.len() + packet.len() >= maximum_transmission_unit;
        if must_flush {
            batches.push(std::mem::take(&mut pending));
        }
        pending.extend_from_slice(packet);
    }

    batches.push(pending);
    batches
}
1160 
/// Splits `bytes` into consecutive owned chunks of at most `split_len` bytes.
///
/// Every chunk except possibly the last is exactly `split_len` bytes long. An
/// empty input yields no chunks at all.
///
/// # Panics
/// Panics if `split_len` is zero.
fn split_bytes(bytes: &[u8], split_len: usize) -> Vec<Vec<u8>> {
    bytes.chunks(split_len).map(<[u8]>::to_vec).collect()
}
1175