1 #[cfg(test)]
2 mod conn_test;
3
4 use crate::alert::*;
5 use crate::application_data::*;
6 use crate::cipher_suite::*;
7 use crate::config::*;
8 use crate::content::*;
9 use crate::curve::named_curve::NamedCurve;
10 use crate::error::*;
11 use crate::extension::extension_use_srtp::*;
12 use crate::flight::flight0::*;
13 use crate::flight::flight1::*;
14 use crate::flight::flight5::*;
15 use crate::flight::flight6::*;
16 use crate::flight::*;
17 use crate::fragment_buffer::*;
18 use crate::handshake::handshake_cache::*;
19 use crate::handshake::handshake_header::HandshakeHeader;
20 use crate::handshake::*;
21 use crate::handshaker::*;
22 use crate::record_layer::record_layer_header::*;
23 use crate::record_layer::*;
24 use crate::signature_hash_algorithm::parse_signature_schemes;
25 use crate::state::*;
26
27 use util::{replay_detector::*, Conn};
28
29 use async_trait::async_trait;
30 use log::*;
31 use std::io::{BufReader, BufWriter};
32 use std::marker::{Send, Sync};
33 use std::net::SocketAddr;
34 use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
35 use std::sync::Arc;
36 use tokio::sync::{mpsc, Mutex};
37 use tokio::time::Duration;
38
39 pub(crate) const INITIAL_TICKER_INTERVAL: Duration = Duration::from_secs(1);
40 pub(crate) const COOKIE_LENGTH: usize = 20;
41 pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519;
42 pub(crate) const INBOUND_BUFFER_SIZE: usize = 8192;
43 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
44 pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64;
45
/// Labels that callers are not allowed to use for keying-material export.
///
/// These strings match the PRF labels of the TLS 1.2 key schedule; presumably they are
/// rejected by the exporter so exported keys never collide with internal ones — confirm
/// at the call site.
pub static INVALID_KEYING_LABELS: &[&str] = &[
    // Finished-message verify_data labels.
    "client finished",
    "server finished",
    // Key-derivation labels.
    "master secret",
    "key expansion",
];
52
/// A batch of packets for the writer task, plus an optional channel on which the
/// writer reports the result of sending the batch.
type PacketSendRequest = (Vec<Packet>, Option<mpsc::Sender<Result<()>>>);

/// State owned by the background reader task spawned in `DTLSConn::new`.
struct ConnReaderContext {
    // Used for log prefixes and to tag handshake-cache entries as coming from the peer.
    is_client: bool,
    // Window size handed to each per-epoch `SlidingWindowDetector`.
    replay_protection_window: usize,
    // One replay detector per epoch, indexed by the record's epoch (grown on demand).
    replay_detector: Vec<Box<dyn ReplayDetector + Send>>,
    // Delivers decrypted application data to `DTLSConn::read`.
    decrypted_tx: mpsc::Sender<Result<Vec<u8>>>,
    // Raw records that could not be processed yet (next epoch / cipher suite not ready);
    // drained and replayed by `read_and_buffer`.
    encrypted_packets: Vec<Vec<u8>>,
    // Reassembles fragmented handshake messages.
    fragment_buffer: FragmentBuffer,
    // Clone of the connection's handshake cache; received handshake messages are pushed here.
    cache: HandshakeCache,
    // Shared with `State::cipher_suite`; used to decrypt records with a non-zero epoch.
    cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
    // Shared with `State::remote_epoch`; advanced on ChangeCipherSpec.
    remote_epoch: Arc<AtomicU16>,
    // Notifies the handshaker that new handshake records were buffered.
    handshake_tx: mpsc::Sender<mpsc::Sender<()>>,
    // Paired with `DTLSConn::handshake_done_tx`.
    handshake_done_rx: mpsc::Receiver<()>,
    // Writer-task queue, used here to send alerts back to the peer.
    packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
}
69
70 // Conn represents a DTLS connection
71 pub struct DTLSConn {
72 conn: Arc<dyn Conn + Send + Sync>,
73 pub(crate) cache: HandshakeCache, // caching of handshake messages for verifyData generation
74 decrypted_rx: Mutex<mpsc::Receiver<Result<Vec<u8>>>>, // Decrypted Application Data or error, pull by calling `Read`
75 pub(crate) state: State, // Internal state
76
77 handshake_completed_successfully: Arc<AtomicBool>,
78 connection_closed_by_user: bool,
79 // closeLock sync.Mutex
80 closed: AtomicBool, // *closer.Closer
81 //handshakeLoopsFinished sync.WaitGroup
82
83 //readDeadline :deadline.Deadline,
84 //writeDeadline :deadline.Deadline,
85
86 //log logging.LeveledLogger
87 /*
88 reading chan struct{}
89 handshakeRecv chan chan struct{}
90 cancelHandshaker func()
91 cancelHandshakeReader func()
92 */
93 pub(crate) current_flight: Box<dyn Flight + Send + Sync>,
94 pub(crate) flights: Option<Vec<Packet>>,
95 pub(crate) cfg: HandshakeConfig,
96 pub(crate) retransmit: bool,
97 pub(crate) handshake_rx: mpsc::Receiver<mpsc::Sender<()>>,
98
99 pub(crate) packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
100 pub(crate) handle_queue_tx: mpsc::Sender<mpsc::Sender<()>>,
101 pub(crate) handshake_done_tx: Option<mpsc::Sender<()>>,
102
103 reader_close_tx: Mutex<Option<mpsc::Sender<()>>>,
104 }
105
106 type UtilResult<T> = std::result::Result<T, util::Error>;
107
108 #[async_trait]
109 impl Conn for DTLSConn {
connect(&self, _addr: SocketAddr) -> UtilResult<()>110 async fn connect(&self, _addr: SocketAddr) -> UtilResult<()> {
111 Err(util::Error::Other("Not applicable".to_owned()))
112 }
recv(&self, buf: &mut [u8]) -> UtilResult<usize>113 async fn recv(&self, buf: &mut [u8]) -> UtilResult<usize> {
114 self.read(buf, None).await.map_err(util::Error::from_std)
115 }
recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)>116 async fn recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)> {
117 if let Some(raddr) = self.conn.remote_addr() {
118 let n = self.read(buf, None).await.map_err(util::Error::from_std)?;
119 Ok((n, raddr))
120 } else {
121 Err(util::Error::Other(
122 "No remote address is provided by underlying Conn".to_owned(),
123 ))
124 }
125 }
send(&self, buf: &[u8]) -> UtilResult<usize>126 async fn send(&self, buf: &[u8]) -> UtilResult<usize> {
127 self.write(buf, None).await.map_err(util::Error::from_std)
128 }
send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize>129 async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize> {
130 Err(util::Error::Other("Not applicable".to_owned()))
131 }
local_addr(&self) -> UtilResult<SocketAddr>132 fn local_addr(&self) -> UtilResult<SocketAddr> {
133 self.conn.local_addr()
134 }
remote_addr(&self) -> Option<SocketAddr>135 fn remote_addr(&self) -> Option<SocketAddr> {
136 self.conn.remote_addr()
137 }
close(&self) -> UtilResult<()>138 async fn close(&self) -> UtilResult<()> {
139 self.close().await.map_err(util::Error::from_std)
140 }
141 }
142
143 impl DTLSConn {
    /// Creates a DTLS connection over `conn` and runs the handshake to completion.
    ///
    /// Steps, in order:
    /// 1. validates `config` and derives the `HandshakeConfig` (cipher suites, signature
    ///    schemes, retransmit interval, MTU, replay window, server name, ...);
    /// 2. chooses the starting flight and FSM state. With `initial_state` the connection
    ///    resumes at `Flight5` (client) / `Flight6` (server) in `HandshakeState::Finished`;
    ///    otherwise it starts at `Flight1` (client) / `Flight0` (server) in
    ///    `HandshakeState::Preparing`;
    /// 3. spawns a writer task that serialises, encrypts and sends every request queued
    ///    on `packet_tx`;
    /// 4. spawns a reader task that repeatedly calls `read_and_buffer`;
    /// 5. awaits `handshake` and returns the connection on success.
    ///
    /// # Errors
    /// Propagates failures from config validation, cipher-suite / signature-scheme
    /// parsing, and the handshake itself.
    pub async fn new(
        conn: Arc<dyn Conn + Send + Sync>,
        mut config: Config,
        is_client: bool,
        initial_state: Option<State>,
    ) -> Result<Self> {
        validate_config(is_client, &config)?;

        // NOTE(review): the two flags appear to enable certificate-based suites when no PSK
        // is configured and PSK suites when one is — confirm against `parse_cipher_suites`.
        // Only the suite IDs are kept for the handshake configuration.
        let local_cipher_suites: Vec<CipherSuiteId> = parse_cipher_suites(
            &config.cipher_suites,
            config.psk.is_none(),
            config.psk.is_some(),
        )?
        .iter()
        .map(|cs| cs.id())
        .collect();

        let sigs: Vec<u16> = config.signature_schemes.iter().map(|x| *x as u16).collect();
        let local_signature_schemes = parse_signature_schemes(&sigs, config.insecure_hashes)?;

        // A zero `flight_interval` means "not configured": fall back to the default.
        let retransmit_interval = if config.flight_interval != Duration::from_secs(0) {
            config.flight_interval
        } else {
            INITIAL_TICKER_INTERVAL
        };

        // Zero means "not configured" for the MTU as well.
        let maximum_transmission_unit = if config.mtu == 0 {
            DEFAULT_MTU
        } else {
            config.mtu
        };

        // ...and for the replay protection window.
        let replay_protection_window = if config.replay_protection_window == 0 {
            DEFAULT_REPLAY_PROTECTION_WINDOW
        } else {
            config.replay_protection_window
        };

        let mut server_name = config.server_name.clone();

        // Use host from conn address when server_name is not provided
        if is_client && server_name.is_empty() {
            if let Some(remote_addr) = conn.remote_addr() {
                server_name = remote_addr.ip().to_string();
            } else {
                log::warn!("conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now");
                server_name = "localhost".to_owned();
            }
        }

        // Options taken out of `config` (PSK callback / identity hint, peer-certificate
        // verifier) and the CA pools are moved into the handshake configuration.
        let cfg = HandshakeConfig {
            local_psk_callback: config.psk.take(),
            local_psk_identity_hint: config.psk_identity_hint.take(),
            local_cipher_suites,
            local_signature_schemes,
            extended_master_secret: config.extended_master_secret,
            local_srtp_protection_profiles: config.srtp_protection_profiles.clone(),
            server_name,
            client_auth: config.client_auth,
            local_certificates: config.certificates.clone(),
            insecure_skip_verify: config.insecure_skip_verify,
            insecure_verification: config.insecure_verification,
            verify_peer_certificate: config.verify_peer_certificate.take(),
            roots_cas: config.roots_cas,
            // A client-certificate verifier is built only when the `client_auth`
            // discriminant is >= `VerifyClientCertIfGiven` (assumes the variants are
            // ordered by strictness — confirm against `ClientAuthType`).
            client_cert_verifier: if config.client_auth as u8
                >= ClientAuthType::VerifyClientCertIfGiven as u8
            {
                Some(rustls::AllowAnyAuthenticatedClient::new(config.client_cas))
            } else {
                None
            },
            retransmit_interval,
            initial_epoch: 0,
            ..Default::default()
        };

        // Resuming from a saved state jumps straight to the final flights; otherwise the
        // handshake starts from the first flight with a fresh state.
        let (state, flight, initial_fsm_state) = if let Some(state) = initial_state {
            let flight = if is_client {
                Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>
            } else {
                Box::new(Flight6 {}) as Box<dyn Flight + Send + Sync>
            };

            (state, flight, HandshakeState::Finished)
        } else {
            let flight = if is_client {
                Box::new(Flight1 {}) as Box<dyn Flight + Send + Sync>
            } else {
                Box::new(Flight0 {}) as Box<dyn Flight + Send + Sync>
            };

            (
                State {
                    is_client,
                    ..Default::default()
                },
                flight,
                HandshakeState::Preparing,
            )
        };

        // Internal channels; every one has capacity 1.
        let (decrypted_tx, decrypted_rx) = mpsc::channel(1);
        let (handshake_tx, handshake_rx) = mpsc::channel(1);
        let (handshake_done_tx, handshake_done_rx) = mpsc::channel(1);
        let (packet_tx, mut packet_rx) = mpsc::channel(1);
        let (handle_queue_tx, mut handle_queue_rx) = mpsc::channel(1);
        let (reader_close_tx, mut reader_close_rx) = mpsc::channel(1);

        // Handles shared between the connection and the two background tasks.
        let packet_tx = Arc::new(packet_tx);
        let packet_tx2 = Arc::clone(&packet_tx);
        let next_conn_rx = Arc::clone(&conn);
        let next_conn_tx = Arc::clone(&conn);
        let cache = HandshakeCache::new();
        let mut cache1 = cache.clone();
        let cache2 = cache.clone();
        let handshake_completed_successfully = Arc::new(AtomicBool::new(false));
        let handshake_completed_successfully2 = Arc::clone(&handshake_completed_successfully);

        let mut c = DTLSConn {
            conn: Arc::clone(&conn),
            cache,
            decrypted_rx: Mutex::new(decrypted_rx),
            state,
            handshake_completed_successfully,
            connection_closed_by_user: false,
            closed: AtomicBool::new(false),

            current_flight: flight,
            flights: None,
            cfg,
            retransmit: false,
            handshake_rx,
            packet_tx,
            handle_queue_tx,
            handshake_done_tx: Some(handshake_done_tx),
            reader_close_tx: Mutex::new(Some(reader_close_tx)),
        };

        let cipher_suite1 = Arc::clone(&c.state.cipher_suite);
        let sequence_number = Arc::clone(&c.state.local_sequence_number);

        // Writer task: handles each batch received on `packet_rx` and, when the request
        // carries a reply channel, reports the result back. It exits once every
        // `packet_tx` sender (the one in `c` and the clone moved into the reader
        // context) has been dropped.
        tokio::spawn(async move {
            loop {
                let rx = packet_rx.recv().await;
                if let Some(r) = rx {
                    let (pkt, result_tx) = r;

                    let result = DTLSConn::handle_outgoing_packets(
                        &next_conn_tx,
                        pkt,
                        &mut cache1,
                        is_client,
                        &sequence_number,
                        &cipher_suite1,
                        maximum_transmission_unit,
                    )
                    .await;

                    if let Some(tx) = result_tx {
                        let _ = tx.send(result).await;
                    }
                } else {
                    trace!("{}: handle_outgoing_packets exit", srv_cli_str(is_client));
                    break;
                }
            }
        });

        let local_epoch = Arc::clone(&c.state.local_epoch);
        let remote_epoch = Arc::clone(&c.state.remote_epoch);
        let cipher_suite2 = Arc::clone(&c.state.cipher_suite);

        // Reader task. It stops when `reader_close_rx` resolves (the sender is taken in
        // `close`, or dropped together with the connection) or when `read_and_buffer`
        // returns `ErrAlertFatalOrClose`.
        // NOTE(review): any other error just loops again immediately. If the underlying
        // `recv` keeps failing without the close channel firing, this may spin — consider
        // exiting on transport errors as well.
        tokio::spawn(async move {
            let mut buf = vec![0u8; INBOUND_BUFFER_SIZE];
            let mut ctx = ConnReaderContext {
                is_client,
                replay_protection_window,
                replay_detector: vec![],
                decrypted_tx,
                encrypted_packets: vec![],
                fragment_buffer: FragmentBuffer::new(),
                cache: cache2,
                cipher_suite: cipher_suite2,
                remote_epoch,
                handshake_tx,
                handshake_done_rx,
                packet_tx: packet_tx2,
            };

            loop {
                tokio::select! {
                    _ = reader_close_rx.recv() => {
                        trace!(
                            "{}: read_and_buffer exit",
                            srv_cli_str(ctx.is_client),
                        );
                        break;
                    }
                    result = DTLSConn::read_and_buffer(
                        &mut ctx,
                        &next_conn_rx,
                        &mut handle_queue_rx,
                        &mut buf,
                        &local_epoch,
                        &handshake_completed_successfully2,
                    ) => {
                        if let Err(err) = result {
                            trace!(
                                "{}: read_and_buffer return err: {}",
                                srv_cli_str(is_client),
                                err
                            );
                            if Error::ErrAlertFatalOrClose == err {
                                trace!(
                                    "{}: read_and_buffer exit with {}",
                                    srv_cli_str(ctx.is_client),
                                    err
                                );

                                break;
                            }
                        }
                    }
                }
            }
        });

        // Do handshake
        c.handshake(initial_fsm_state).await?;

        trace!("Handshake Completed");

        Ok(c)
    }
388
389 // Read reads data from the connection.
read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize>390 pub async fn read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize> {
391 if !self.is_handshake_completed_successfully() {
392 return Err(Error::ErrHandshakeInProgress);
393 }
394
395 let rx = {
396 let mut decrypted_rx = self.decrypted_rx.lock().await;
397 if let Some(d) = duration {
398 let timer = tokio::time::sleep(d);
399 tokio::pin!(timer);
400
401 tokio::select! {
402 r = decrypted_rx.recv() => r,
403 _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
404 }
405 } else {
406 decrypted_rx.recv().await
407 }
408 };
409
410 if let Some(out) = rx {
411 match out {
412 Ok(val) => {
413 let n = val.len();
414 if p.len() < n {
415 return Err(Error::ErrBufferTooSmall);
416 }
417 p[..n].copy_from_slice(&val);
418 Ok(n)
419 }
420 Err(err) => Err(err),
421 }
422 } else {
423 Err(Error::ErrAlertFatalOrClose)
424 }
425 }
426
427 // Write writes len(p) bytes from p to the DTLS connection
write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize>428 pub async fn write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize> {
429 if self.is_connection_closed() {
430 return Err(Error::ErrConnClosed);
431 }
432
433 if !self.is_handshake_completed_successfully() {
434 return Err(Error::ErrHandshakeInProgress);
435 }
436
437 let pkts = vec![Packet {
438 record: RecordLayer::new(
439 PROTOCOL_VERSION1_2,
440 self.get_local_epoch(),
441 Content::ApplicationData(ApplicationData { data: p.to_vec() }),
442 ),
443 should_encrypt: true,
444 reset_local_sequence_number: false,
445 }];
446
447 if let Some(d) = duration {
448 let timer = tokio::time::sleep(d);
449 tokio::pin!(timer);
450
451 tokio::select! {
452 result = self.write_packets(pkts) => {
453 result?;
454 }
455 _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
456 }
457 } else {
458 self.write_packets(pkts).await?;
459 }
460
461 Ok(p.len())
462 }
463
464 // Close closes the connection.
close(&self) -> Result<()>465 pub async fn close(&self) -> Result<()> {
466 if !self.closed.load(Ordering::SeqCst) {
467 self.closed.store(true, Ordering::SeqCst);
468
469 // Discard error from notify() to return non-error on the first user call of Close()
470 // even if the underlying connection is already closed.
471 self.notify(AlertLevel::Warning, AlertDescription::CloseNotify)
472 .await?;
473
474 {
475 let mut reader_close_tx = self.reader_close_tx.lock().await;
476 reader_close_tx.take();
477 }
478 self.conn.close().await?;
479 }
480
481 Ok(())
482 }
483
484 /// connection_state returns basic DTLS details about the connection.
485 /// Note that this replaced the `Export` function of v1.
connection_state(&self) -> State486 pub async fn connection_state(&self) -> State {
487 self.state.clone().await
488 }
489
490 /// selected_srtpprotection_profile returns the selected SRTPProtectionProfile
selected_srtpprotection_profile(&self) -> SrtpProtectionProfile491 pub fn selected_srtpprotection_profile(&self) -> SrtpProtectionProfile {
492 self.state.srtp_protection_profile
493 }
494
notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()>495 pub(crate) async fn notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()> {
496 self.write_packets(vec![Packet {
497 record: RecordLayer::new(
498 PROTOCOL_VERSION1_2,
499 self.get_local_epoch(),
500 Content::Alert(Alert {
501 alert_level: level,
502 alert_description: desc,
503 }),
504 ),
505 should_encrypt: self.is_handshake_completed_successfully(),
506 reset_local_sequence_number: false,
507 }])
508 .await
509 }
510
write_packets(&self, pkts: Vec<Packet>) -> Result<()>511 pub(crate) async fn write_packets(&self, pkts: Vec<Packet>) -> Result<()> {
512 let (tx, mut rx) = mpsc::channel(1);
513
514 self.packet_tx.send((pkts, Some(tx))).await?;
515
516 if let Some(result) = rx.recv().await {
517 result
518 } else {
519 Ok(())
520 }
521 }
522
handle_outgoing_packets( next_conn: &Arc<dyn util::Conn + Send + Sync>, mut pkts: Vec<Packet>, cache: &mut HandshakeCache, is_client: bool, local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, maximum_transmission_unit: usize, ) -> Result<()>523 async fn handle_outgoing_packets(
524 next_conn: &Arc<dyn util::Conn + Send + Sync>,
525 mut pkts: Vec<Packet>,
526 cache: &mut HandshakeCache,
527 is_client: bool,
528 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
529 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
530 maximum_transmission_unit: usize,
531 ) -> Result<()> {
532 let mut raw_packets = vec![];
533 for p in &mut pkts {
534 if let Content::Handshake(h) = &p.record.content {
535 let mut handshake_raw = vec![];
536 {
537 let mut writer = BufWriter::<&mut Vec<u8>>::new(handshake_raw.as_mut());
538 p.record.marshal(&mut writer)?;
539 }
540 trace!(
541 "Send [handshake:{}] -> {} (epoch: {}, seq: {})",
542 srv_cli_str(is_client),
543 h.handshake_header.handshake_type.to_string(),
544 p.record.record_layer_header.epoch,
545 h.handshake_header.message_sequence
546 );
547 cache
548 .push(
549 handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(),
550 p.record.record_layer_header.epoch,
551 h.handshake_header.message_sequence,
552 h.handshake_header.handshake_type,
553 is_client,
554 )
555 .await;
556
557 let raw_handshake_packets = DTLSConn::process_handshake_packet(
558 local_sequence_number,
559 cipher_suite,
560 maximum_transmission_unit,
561 p,
562 h,
563 )
564 .await?;
565 raw_packets.extend_from_slice(&raw_handshake_packets);
566 } else {
567 /*if let Content::Alert(a) = &p.record.content {
568 if a.alert_description == AlertDescription::CloseNotify {
569 closed = true;
570 }
571 }*/
572
573 let raw_packet =
574 DTLSConn::process_packet(local_sequence_number, cipher_suite, p).await?;
575 raw_packets.push(raw_packet);
576 }
577 }
578
579 if !raw_packets.is_empty() {
580 let compacted_raw_packets =
581 compact_raw_packets(&raw_packets, maximum_transmission_unit);
582
583 for compacted_raw_packets in &compacted_raw_packets {
584 next_conn.send(compacted_raw_packets).await?;
585 }
586 }
587
588 Ok(())
589 }
590
process_packet( local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, p: &mut Packet, ) -> Result<Vec<u8>>591 async fn process_packet(
592 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
593 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
594 p: &mut Packet,
595 ) -> Result<Vec<u8>> {
596 let epoch = p.record.record_layer_header.epoch as usize;
597 let seq = {
598 let mut lsn = local_sequence_number.lock().await;
599 while lsn.len() <= epoch {
600 lsn.push(0);
601 }
602
603 lsn[epoch] += 1;
604 lsn[epoch] - 1
605 };
606 //trace!("{}: seq = {}", srv_cli_str(is_client), seq);
607
608 if seq > MAX_SEQUENCE_NUMBER {
609 // RFC 6347 Section 4.1.0
610 // The implementation must either abandon an association or rehandshake
611 // prior to allowing the sequence number to wrap.
612 return Err(Error::ErrSequenceNumberOverflow);
613 }
614 p.record.record_layer_header.sequence_number = seq;
615
616 let mut raw_packet = vec![];
617 {
618 let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_packet.as_mut());
619 p.record.marshal(&mut writer)?;
620 }
621
622 if p.should_encrypt {
623 let cipher_suite = cipher_suite.lock().await;
624 if let Some(cipher_suite) = &*cipher_suite {
625 raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?;
626 }
627 }
628
629 Ok(raw_packet)
630 }
631
process_handshake_packet( local_sequence_number: &Arc<Mutex<Vec<u64>>>, cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>, maximum_transmission_unit: usize, p: &Packet, h: &Handshake, ) -> Result<Vec<Vec<u8>>>632 async fn process_handshake_packet(
633 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
634 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
635 maximum_transmission_unit: usize,
636 p: &Packet,
637 h: &Handshake,
638 ) -> Result<Vec<Vec<u8>>> {
639 let mut raw_packets = vec![];
640
641 let handshake_fragments = DTLSConn::fragment_handshake(maximum_transmission_unit, h)?;
642
643 let epoch = p.record.record_layer_header.epoch as usize;
644
645 let mut lsn = local_sequence_number.lock().await;
646 while lsn.len() <= epoch {
647 lsn.push(0);
648 }
649
650 for handshake_fragment in &handshake_fragments {
651 let seq = {
652 lsn[epoch] += 1;
653 lsn[epoch] - 1
654 };
655 //trace!("seq = {}", seq);
656 if seq > MAX_SEQUENCE_NUMBER {
657 return Err(Error::ErrSequenceNumberOverflow);
658 }
659
660 let record_layer_header = RecordLayerHeader {
661 protocol_version: p.record.record_layer_header.protocol_version,
662 content_type: p.record.record_layer_header.content_type,
663 content_len: handshake_fragment.len() as u16,
664 epoch: p.record.record_layer_header.epoch,
665 sequence_number: seq,
666 };
667
668 let mut record_layer_header_bytes = vec![];
669 {
670 let mut writer = BufWriter::<&mut Vec<u8>>::new(record_layer_header_bytes.as_mut());
671 record_layer_header.marshal(&mut writer)?;
672 }
673
674 //p.record.record_layer_header = record_layer_header;
675
676 let mut raw_packet = vec![];
677 raw_packet.extend_from_slice(&record_layer_header_bytes);
678 raw_packet.extend_from_slice(handshake_fragment);
679 if p.should_encrypt {
680 let cipher_suite = cipher_suite.lock().await;
681 if let Some(cipher_suite) = &*cipher_suite {
682 raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?;
683 }
684 }
685
686 raw_packets.push(raw_packet);
687 }
688
689 Ok(raw_packets)
690 }
691
fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>>692 fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>> {
693 let mut content = vec![];
694 {
695 let mut writer = BufWriter::<&mut Vec<u8>>::new(content.as_mut());
696 h.handshake_message.marshal(&mut writer)?;
697 }
698
699 let mut fragmented_handshakes = vec![];
700
701 let mut content_fragments = split_bytes(&content, maximum_transmission_unit);
702 if content_fragments.is_empty() {
703 content_fragments = vec![vec![]];
704 }
705
706 let mut offset = 0;
707 for content_fragment in &content_fragments {
708 let content_fragment_len = content_fragment.len();
709
710 let handshake_header_fragment = HandshakeHeader {
711 handshake_type: h.handshake_header.handshake_type,
712 length: h.handshake_header.length,
713 message_sequence: h.handshake_header.message_sequence,
714 fragment_offset: offset as u32,
715 fragment_length: content_fragment_len as u32,
716 };
717
718 offset += content_fragment_len;
719
720 let mut handshake_header_fragment_raw = vec![];
721 {
722 let mut writer =
723 BufWriter::<&mut Vec<u8>>::new(handshake_header_fragment_raw.as_mut());
724 handshake_header_fragment.marshal(&mut writer)?;
725 }
726
727 let mut fragmented_handshake = vec![];
728 fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw);
729 fragmented_handshake.extend_from_slice(content_fragment);
730
731 fragmented_handshakes.push(fragmented_handshake);
732 }
733
734 Ok(fragmented_handshakes)
735 }
736
set_handshake_completed_successfully(&mut self)737 pub(crate) fn set_handshake_completed_successfully(&mut self) {
738 self.handshake_completed_successfully
739 .store(true, Ordering::SeqCst);
740 }
741
is_handshake_completed_successfully(&self) -> bool742 pub(crate) fn is_handshake_completed_successfully(&self) -> bool {
743 self.handshake_completed_successfully.load(Ordering::SeqCst)
744 }
745
    /// Reads one datagram from `next_conn` and processes every record in it.
    ///
    /// Each record goes through `handle_incoming_packet` with re-queuing enabled. If that
    /// yields an alert, the alert is queued on `ctx.packet_tx`, encrypted only when the
    /// handshake has completed. A fatal alert or a close_notify then makes this function
    /// return `ErrAlertFatalOrClose`; any other reported error is returned as is.
    ///
    /// If at least one handshake record was buffered, the handshaker is notified by
    /// sending a fresh reply sender on `ctx.handshake_tx`. While waiting for that reply
    /// (or for the sender to be dropped), queued packets are replayed every time a
    /// request arrives on `handle_queue_rx`. If `ctx.handshake_done_rx` fires before the
    /// notification is accepted, the notification is abandoned.
    async fn read_and_buffer(
        ctx: &mut ConnReaderContext,
        next_conn: &Arc<dyn util::Conn + Send + Sync>,
        handle_queue_rx: &mut mpsc::Receiver<mpsc::Sender<()>>,
        buf: &mut [u8],
        local_epoch: &Arc<AtomicU16>,
        handshake_completed_successfully: &Arc<AtomicBool>,
    ) -> Result<()> {
        let n = next_conn.recv(buf).await?;
        // A single datagram may carry several records.
        let pkts = unpack_datagram(&buf[..n])?;
        let mut has_handshake = false;
        for pkt in pkts {
            let (hs, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, pkt, true).await;
            if let Some(alert) = alert {
                // Answer with the alert chosen by `handle_incoming_packet`; the writer task
                // sends it without reporting a result (no reply channel).
                let alert_err = ctx
                    .packet_tx
                    .send((
                        vec![Packet {
                            record: RecordLayer::new(
                                PROTOCOL_VERSION1_2,
                                local_epoch.load(Ordering::SeqCst),
                                Content::Alert(Alert {
                                    alert_level: alert.alert_level,
                                    alert_description: alert.alert_description,
                                }),
                            ),
                            should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
                            reset_local_sequence_number: false,
                        }],
                        None,
                    ))
                    .await;

                // A failure to queue the alert is reported only if nothing else went wrong.
                if let Err(alert_err) = alert_err {
                    if err.is_none() {
                        err = Some(Error::Other(alert_err.to_string()));
                    }
                }

                // The reader task in `new` stops on this specific error.
                if alert.alert_level == AlertLevel::Fatal
                    || alert.alert_description == AlertDescription::CloseNotify
                {
                    return Err(Error::ErrAlertFatalOrClose);
                }
            }

            if let Some(err) = err {
                return Err(err);
            }

            if hs {
                has_handshake = true
            }
        }

        if has_handshake {
            let (done_tx, mut done_rx) = mpsc::channel(1);

            tokio::select! {
                _ = ctx.handshake_tx.send(done_tx) => {
                    let mut wait_done_rx = true;
                    while wait_done_rx{
                        tokio::select!{
                            _ = done_rx.recv() => {
                                // If the other party may retransmit the flight,
                                // we should respond even if it not a new message.
                                wait_done_rx = false;
                            }
                            done = handle_queue_rx.recv() => {
                                // Replay every queued packet (e.g. records of the next
                                // epoch) now that the handshaker asked for it.
                                let pkts = ctx.encrypted_packets.drain(..).collect();
                                DTLSConn::handle_queued_packets(ctx, local_epoch, handshake_completed_successfully, pkts).await?;

                                // Dropping the received sender presumably tells the
                                // requester the queue has been processed — confirm in handshaker.
                                drop(done);
                            }
                        }
                    }
                }
                _ = ctx.handshake_done_rx.recv() => {}
            }
        }

        Ok(())
    }
831
handle_queued_packets( ctx: &mut ConnReaderContext, local_epoch: &Arc<AtomicU16>, handshake_completed_successfully: &Arc<AtomicBool>, pkts: Vec<Vec<u8>>, ) -> Result<()>832 async fn handle_queued_packets(
833 ctx: &mut ConnReaderContext,
834 local_epoch: &Arc<AtomicU16>,
835 handshake_completed_successfully: &Arc<AtomicBool>,
836 pkts: Vec<Vec<u8>>,
837 ) -> Result<()> {
838 for p in pkts {
839 let (_, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, p, false).await; // don't re-enqueue
840 if let Some(alert) = alert {
841 let alert_err = ctx
842 .packet_tx
843 .send((
844 vec![Packet {
845 record: RecordLayer::new(
846 PROTOCOL_VERSION1_2,
847 local_epoch.load(Ordering::SeqCst),
848 Content::Alert(Alert {
849 alert_level: alert.alert_level,
850 alert_description: alert.alert_description,
851 }),
852 ),
853 should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
854 reset_local_sequence_number: false,
855 }],
856 None,
857 ))
858 .await;
859
860 if let Err(alert_err) = alert_err {
861 if err.is_none() {
862 err = Some(Error::Other(alert_err.to_string()));
863 }
864 }
865 if alert.alert_level == AlertLevel::Fatal
866 || alert.alert_description == AlertDescription::CloseNotify
867 {
868 return Err(Error::ErrAlertFatalOrClose);
869 }
870 }
871
872 if let Some(err) = err {
873 return Err(err);
874 }
875 }
876
877 Ok(())
878 }
879
    /// Processes one raw record received from the peer.
    ///
    /// Returns `(is_handshake, alert, error)`:
    /// * `is_handshake` is `true` when a handshake record was accepted into the fragment
    ///   buffer; any completed messages are then pushed into the handshake cache;
    /// * `alert` is an alert the caller should send back to the peer;
    /// * `error` is an error the caller should report.
    ///
    /// The following records are dropped silently (`(false, None, None)`):
    /// * records that fail to parse;
    /// * records from an epoch more than one ahead;
    /// * duplicates flagged by the replay detector;
    /// * records that fail to decrypt.
    ///
    /// Records of the next epoch, and records that arrive before the cipher suite is
    /// initialised, are pushed onto `ctx.encrypted_packets` when `enqueue` is `true`, so
    /// they can be replayed later.
    async fn handle_incoming_packet(
        ctx: &mut ConnReaderContext,
        mut pkt: Vec<u8>,
        enqueue: bool,
    ) -> (bool, Option<Alert>, Option<Error>) {
        let mut reader = BufReader::new(pkt.as_slice());
        let h = match RecordLayerHeader::unmarshal(&mut reader) {
            Ok(h) => h,
            Err(err) => {
                // Decode error must be silently discarded
                // [RFC6347 Section-4.1.2.7]
                debug!(
                    "{}: discarded broken packet: {}",
                    srv_cli_str(ctx.is_client),
                    err
                );
                return (false, None, None);
            }
        };

        // Validate epoch: drop anything beyond the next epoch and keep records of the next
        // epoch aside until the cipher change has been processed.
        let epoch = ctx.remote_epoch.load(Ordering::SeqCst);
        if h.epoch > epoch {
            if h.epoch > epoch + 1 {
                debug!(
                    "{}: discarded future packet (epoch: {}, seq: {})",
                    srv_cli_str(ctx.is_client),
                    h.epoch,
                    h.sequence_number,
                );
                return (false, None, None);
            }
            if enqueue {
                debug!(
                    "{}: received packet of next epoch, queuing packet",
                    srv_cli_str(ctx.is_client)
                );
                ctx.encrypted_packets.push(pkt);
            }
            return (false, None, None);
        }

        // Anti-replay protection: one detector per epoch, created on demand.
        while ctx.replay_detector.len() <= h.epoch as usize {
            ctx.replay_detector
                .push(Box::new(SlidingWindowDetector::new(
                    ctx.replay_protection_window,
                    MAX_SEQUENCE_NUMBER,
                )));
        }

        // `check` only tests the sequence number; it is recorded via `accept` further down,
        // once the record has been processed successfully.
        let ok = ctx.replay_detector[h.epoch as usize].check(h.sequence_number);
        if !ok {
            debug!(
                "{}: discarded duplicated packet (epoch: {}, seq: {})",
                srv_cli_str(ctx.is_client),
                h.epoch,
                h.sequence_number,
            );
            return (false, None, None);
        }

        // Decrypt (epoch 0 records are plaintext)
        if h.epoch != 0 {
            // True when no cipher suite is set or it is not initialised yet
            // (the final `else` branch is unreachable).
            let invalid_cipher_suite = {
                let cipher_suite = ctx.cipher_suite.lock().await;
                if cipher_suite.is_none() {
                    true
                } else if let Some(cipher_suite) = &*cipher_suite {
                    !cipher_suite.is_initialized()
                } else {
                    false
                }
            };
            if invalid_cipher_suite {
                if enqueue {
                    debug!(
                        "{}: handshake not finished, queuing packet",
                        srv_cli_str(ctx.is_client)
                    );
                    ctx.encrypted_packets.push(pkt);
                }
                return (false, None, None);
            }

            let cipher_suite = ctx.cipher_suite.lock().await;
            if let Some(cipher_suite) = &*cipher_suite {
                pkt = match cipher_suite.decrypt(&pkt) {
                    Ok(pkt) => pkt,
                    Err(err) => {
                        debug!("{}: decrypt failed: {}", srv_cli_str(ctx.is_client), err);
                        return (false, None, None);
                    }
                };
            }
        }

        // `push` reports whether the record was a handshake record it buffered.
        let is_handshake = match ctx.fragment_buffer.push(&pkt) {
            Ok(is_handshake) => is_handshake,
            Err(err) => {
                // Decode error must be silently discarded
                // [RFC6347 Section-4.1.2.7]
                debug!("{}: defragment failed: {}", srv_cli_str(ctx.is_client), err);
                return (false, None, None);
            }
        };
        if is_handshake {
            ctx.replay_detector[h.epoch as usize].accept();
            // Move every fully reassembled handshake message into the cache. Messages that
            // fail to parse are skipped.
            while let Ok((out, epoch)) = ctx.fragment_buffer.pop() {
                let mut reader = BufReader::new(out.as_slice());
                let raw_handshake = match Handshake::unmarshal(&mut reader) {
                    Ok(rh) => {
                        trace!(
                            "Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
                            srv_cli_str(ctx.is_client),
                            rh.handshake_header.handshake_type.to_string(),
                            h.epoch,
                            rh.handshake_header.message_sequence
                        );
                        rh
                    }
                    Err(err) => {
                        debug!(
                            "{}: handshake parse failed: {}",
                            srv_cli_str(ctx.is_client),
                            err
                        );
                        continue;
                    }
                };

                // `!ctx.is_client`: the message was sent by the peer.
                ctx.cache
                    .push(
                        out,
                        epoch,
                        raw_handshake.handshake_header.message_sequence,
                        raw_handshake.handshake_header.handshake_type,
                        !ctx.is_client,
                    )
                    .await;
            }

            return (true, None, None);
        }

        let mut reader = BufReader::new(pkt.as_slice());
        let r = match RecordLayer::unmarshal(&mut reader) {
            Ok(r) => r,
            Err(err) => {
                return (
                    false,
                    Some(Alert {
                        alert_level: AlertLevel::Fatal,
                        alert_description: AlertDescription::DecodeError,
                    }),
                    Some(err),
                );
            }
        };

        match r.content {
            Content::Alert(mut a) => {
                trace!("{}: <- {}", srv_cli_str(ctx.is_client), a.to_string());
                if a.alert_description == AlertDescription::CloseNotify {
                    // Respond with a close_notify [RFC5246 Section 7.2.1]
                    a = Alert {
                        alert_level: AlertLevel::Warning,
                        alert_description: AlertDescription::CloseNotify,
                    };
                }
                ctx.replay_detector[h.epoch as usize].accept();
                // NOTE(review): every received alert is returned as the alert to send, so
                // alerts other than close_notify are echoed back to the peer. Callers rely
                // on the returned alert's level to stop on fatal alerts — confirm the echo
                // is intended before changing it.
                return (
                    false,
                    Some(a),
                    Some(Error::Other(format!("Error of Alert {a}"))),
                );
            }
            Content::ChangeCipherSpec(_) => {
                // Same readiness test as above: no cipher suite, or not yet initialised.
                let invalid_cipher_suite = {
                    let cipher_suite = ctx.cipher_suite.lock().await;
                    if cipher_suite.is_none() {
                        true
                    } else if let Some(cipher_suite) = &*cipher_suite {
                        !cipher_suite.is_initialized()
                    } else {
                        false
                    }
                };

                if invalid_cipher_suite {
                    if enqueue {
                        debug!(
                            "{}: CipherSuite not initialized, queuing packet",
                            srv_cli_str(ctx.is_client)
                        );
                        ctx.encrypted_packets.push(pkt);
                    }
                    return (false, None, None);
                }

                let new_remote_epoch = h.epoch + 1;
                trace!(
                    "{}: <- ChangeCipherSpec (epoch: {})",
                    srv_cli_str(ctx.is_client),
                    new_remote_epoch
                );

                // Advance the remote epoch only for a ChangeCipherSpec sent in the current
                // epoch (i.e. `h.epoch == epoch`); older retransmissions are ignored.
                if epoch + 1 == new_remote_epoch {
                    ctx.remote_epoch.store(new_remote_epoch, Ordering::SeqCst);
                    ctx.replay_detector[h.epoch as usize].accept();
                }
            }
            Content::ApplicationData(a) => {
                // Application data is never valid before the first cipher change.
                if h.epoch == 0 {
                    return (
                        false,
                        Some(Alert {
                            alert_level: AlertLevel::Fatal,
                            alert_description: AlertDescription::UnexpectedMessage,
                        }),
                        Some(Error::ErrApplicationDataEpochZero),
                    );
                }

                ctx.replay_detector[h.epoch as usize].accept();

                // Hand the data to `read`. This waits while the channel is full; a send
                // error (receiver gone) is ignored.
                let _ = ctx.decrypted_tx.send(Ok(a.data)).await;
                //TODO
                /*select {
                case self.decrypted < - content.data:
                case < -c.closed.Done():
                }*/
            }
            _ => {
                return (
                    false,
                    Some(Alert {
                        alert_level: AlertLevel::Fatal,
                        alert_description: AlertDescription::UnexpectedMessage,
                    }),
                    Some(Error::ErrUnhandledContextType),
                );
            }
        };

        (false, None, None)
    }
1128
is_connection_closed(&self) -> bool1129 fn is_connection_closed(&self) -> bool {
1130 self.closed.load(Ordering::SeqCst)
1131 }
1132
set_local_epoch(&mut self, epoch: u16)1133 pub(crate) fn set_local_epoch(&mut self, epoch: u16) {
1134 self.state.local_epoch.store(epoch, Ordering::SeqCst);
1135 }
1136
get_local_epoch(&self) -> u161137 pub(crate) fn get_local_epoch(&self) -> u16 {
1138 self.state.local_epoch.load(Ordering::SeqCst)
1139 }
1140 }
1141
/// Packs serialised records into as few datagrams as possible.
///
/// Records are appended, in order, to the current datagram. A new datagram is started
/// when adding the next record would make the current one reach (`>=`)
/// `maximum_transmission_unit` bytes. Records are never split, so a record larger than
/// the MTU still forms a datagram of its own.
///
/// An empty `raw_packets` yields no datagrams; previously it produced a single empty
/// datagram, which would have been sent as a zero-length payload.
fn compact_raw_packets(raw_packets: &[Vec<u8>], maximum_transmission_unit: usize) -> Vec<Vec<u8>> {
    if raw_packets.is_empty() {
        return Vec::new();
    }

    let mut combined_raw_packets = vec![];
    let mut current_combined_raw_packet: Vec<u8> = vec![];

    for raw_packet in raw_packets {
        if !current_combined_raw_packet.is_empty()
            && current_combined_raw_packet.len() + raw_packet.len() >= maximum_transmission_unit
        {
            // Hand over the finished datagram and start a fresh one.
            combined_raw_packets.push(std::mem::take(&mut current_combined_raw_packet));
        }
        current_combined_raw_packet.extend_from_slice(raw_packet);
    }

    combined_raw_packets.push(current_combined_raw_packet);

    combined_raw_packets
}
1160
/// Splits `bytes` into consecutive owned chunks of at most `split_len` bytes.
///
/// The last chunk holds the remainder; an empty input yields no chunks.
///
/// # Panics
/// Panics if `split_len` is zero.
fn split_bytes(bytes: &[u8], split_len: usize) -> Vec<Vec<u8>> {
    bytes.chunks(split_len).map(|chunk| chunk.to_vec()).collect()
}
1175