xref: /webrtc/dtls/src/fragment_buffer/mod.rs (revision 630c46fe)
1 #[cfg(test)]
2 mod fragment_buffer_test;
3 
4 use crate::content::*;
5 use crate::error::*;
6 use crate::handshake::handshake_header::*;
7 use crate::record_layer::record_layer_header::*;
8 
9 use std::collections::HashMap;
10 use std::io::{BufWriter, Cursor};
11 
// 2 mb max buffer size; push() rejects a packet when the total cached
// fragment payload plus the incoming packet would reach this limit.
const FRAGMENT_BUFFER_MAX_SIZE: usize = 2_000_000;
14 
/// A single handshake fragment as received off the wire: the record-layer and
/// handshake headers it arrived under, plus its payload bytes (headers stripped).
pub(crate) struct Fragment {
    record_layer_header: RecordLayerHeader,
    handshake_header: HandshakeHeader,
    data: Vec<u8>,
}
20 
/// Collects DTLS handshake fragments and reassembles them into complete
/// handshake messages, delivered in message-sequence order via `pop()`.
pub(crate) struct FragmentBuffer {
    // map of MessageSequenceNumbers that hold slices of fragments
    cache: HashMap<u16, Vec<Fragment>>,

    // the next message_sequence that pop() will attempt to reassemble
    current_message_sequence_number: u16,
}
27 
28 impl FragmentBuffer {
new() -> Self29     pub fn new() -> Self {
30         FragmentBuffer {
31             cache: HashMap::new(),
32             current_message_sequence_number: 0,
33         }
34     }
35 
36     // Attempts to push a DTLS packet to the FragmentBuffer
37     // when it returns true it means the FragmentBuffer has inserted and the buffer shouldn't be handled
38     // when an error returns it is fatal, and the DTLS connection should be stopped
push(&mut self, mut buf: &[u8]) -> Result<bool>39     pub fn push(&mut self, mut buf: &[u8]) -> Result<bool> {
40         let current_size = self.size();
41         if current_size + buf.len() >= FRAGMENT_BUFFER_MAX_SIZE {
42             return Err(Error::ErrFragmentBufferOverflow {
43                 new_size: current_size + buf.len(),
44                 max_size: FRAGMENT_BUFFER_MAX_SIZE,
45             });
46         }
47 
48         let mut reader = Cursor::new(buf);
49         let record_layer_header = RecordLayerHeader::unmarshal(&mut reader)?;
50 
51         // Fragment isn't a handshake, we don't need to handle it
52         if record_layer_header.content_type != ContentType::Handshake {
53             return Ok(false);
54         }
55 
56         buf = &buf[RECORD_LAYER_HEADER_SIZE..];
57         while !buf.is_empty() {
58             let mut reader = Cursor::new(buf);
59             let handshake_header = HandshakeHeader::unmarshal(&mut reader)?;
60 
61             self.cache
62                 .entry(handshake_header.message_sequence)
63                 .or_insert_with(Vec::new);
64 
65             // end index should be the length of handshake header but if the handshake
66             // was fragmented, we should keep them all
67             let mut end = HANDSHAKE_HEADER_LENGTH + handshake_header.length as usize;
68             if end > buf.len() {
69                 end = buf.len();
70             }
71 
72             // Discard all headers, when rebuilding the packet we will re-build
73             let data = buf[HANDSHAKE_HEADER_LENGTH..end].to_vec();
74 
75             if let Some(x) = self.cache.get_mut(&handshake_header.message_sequence) {
76                 x.push(Fragment {
77                     record_layer_header,
78                     handshake_header,
79                     data,
80                 });
81             }
82             buf = &buf[end..];
83         }
84 
85         Ok(true)
86     }
87 
pop(&mut self) -> Result<(Vec<u8>, u16)>88     pub fn pop(&mut self) -> Result<(Vec<u8>, u16)> {
89         let seq_num = self.current_message_sequence_number;
90         if !self.cache.contains_key(&seq_num) {
91             return Err(Error::ErrEmptyFragment);
92         }
93 
94         let (content, epoch) = if let Some(frags) = self.cache.get_mut(&seq_num) {
95             let mut raw_message = vec![];
96             // Recursively collect up
97             if !append_message(0, frags, &mut raw_message) {
98                 return Err(Error::ErrEmptyFragment);
99             }
100 
101             let mut first_header = frags[0].handshake_header;
102             first_header.fragment_offset = 0;
103             first_header.fragment_length = first_header.length;
104 
105             let mut raw_header = vec![];
106             {
107                 let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_header.as_mut());
108                 if first_header.marshal(&mut writer).is_err() {
109                     return Err(Error::ErrEmptyFragment);
110                 }
111             }
112 
113             let message_epoch = frags[0].record_layer_header.epoch;
114 
115             raw_header.extend_from_slice(&raw_message);
116 
117             (raw_header, message_epoch)
118         } else {
119             return Err(Error::ErrEmptyFragment);
120         };
121 
122         self.cache.remove(&seq_num);
123         self.current_message_sequence_number += 1;
124 
125         Ok((content, epoch))
126     }
127 
size(&self) -> usize128     fn size(&self) -> usize {
129         self.cache
130             .values()
131             .map(|fragment| fragment.iter().map(|f| f.data.len()).sum::<usize>())
132             .sum()
133     }
134 }
135 
/// Recursively reassembles a handshake message from `frags` into
/// `raw_message`: finds the fragment whose `fragment_offset` equals
/// `target_offset`, recurses for the bytes that follow it, then prepends this
/// fragment's payload. Returns `false` when no fragment covers a needed
/// offset (chain incomplete), leaving `raw_message` partially filled.
fn append_message(target_offset: u32, frags: &[Fragment], raw_message: &mut Vec<u8>) -> bool {
    for f in frags {
        if f.handshake_header.fragment_offset == target_offset {
            let fragment_end =
                f.handshake_header.fragment_offset + f.handshake_header.fragment_length;

            // NB: Order here is important, the `f.handshake_header.fragment_length != 0`
            // MUST come before the recursive call: a zero-length fragment would
            // recurse with the same offset and never terminate.
            if fragment_end != f.handshake_header.length
                && f.handshake_header.fragment_length != 0
                && !append_message(fragment_end, frags, raw_message)
            {
                return false;
            }

            // Prepend this fragment's payload to the bytes collected so far
            // (the recursion above filled in everything after `fragment_end`).
            let mut message = vec![];
            message.extend_from_slice(&f.data);
            message.extend_from_slice(raw_message);
            *raw_message = message;
            return true;
        }
    }

    false
}
161