xref: /webrtc/rtp/src/header.rs (revision 8b68b472)
1 use crate::error::Error;
2 use util::marshal::{Marshal, MarshalSize, Unmarshal};
3 
4 use bytes::{Buf, BufMut, Bytes};
5 
/// Minimum number of bytes required before the first header byte is parsed.
pub const HEADER_LENGTH: usize = 4;

// ---- First header byte: |V V|P|X|C C C C| ----

/// Bit position of the 2-bit version (V) field within the first byte.
pub const VERSION_SHIFT: u8 = 6;
/// Mask applied to the version field after shifting.
pub const VERSION_MASK: u8 = 0b11;
/// Bit position of the padding (P) flag within the first byte.
pub const PADDING_SHIFT: u8 = 5;
/// Mask applied to the padding flag after shifting.
pub const PADDING_MASK: u8 = 0b1;
/// Bit position of the extension (X) flag within the first byte.
pub const EXTENSION_SHIFT: u8 = 4;
/// Mask applied to the extension flag after shifting.
pub const EXTENSION_MASK: u8 = 0b1;
/// Mask selecting the 4-bit CSRC count (CC) in the low bits of the first byte.
pub const CC_MASK: u8 = 0b1111;

// ---- Second header byte: |M|P T P T P T P| ----

/// Bit position of the marker (M) flag within the second byte.
pub const MARKER_SHIFT: u8 = 7;
/// Mask applied to the marker flag after shifting.
pub const MARKER_MASK: u8 = 0b1;
/// Mask selecting the 7-bit payload type (PT) in the second byte.
pub const PT_MASK: u8 = 0b0111_1111;

// ---- Header extension profiles (RFC 8285) ----

/// "Defined by profile" value announcing RFC 8285 one-byte extension elements.
pub const EXTENSION_PROFILE_ONE_BYTE: u16 = 0xBEDE;
/// "Defined by profile" value announcing RFC 8285 two-byte extension elements.
pub const EXTENSION_PROFILE_TWO_BYTE: u16 = 0x1000;
/// One-byte element id that terminates extension processing.
pub const EXTENSION_ID_RESERVED: u8 = 15;

// ---- Byte offsets / lengths of the fixed fields, each following the previous one ----

pub const SEQ_NUM_OFFSET: usize = 2;
pub const SEQ_NUM_LENGTH: usize = 2;
pub const TIMESTAMP_OFFSET: usize = SEQ_NUM_OFFSET + SEQ_NUM_LENGTH;
pub const TIMESTAMP_LENGTH: usize = 4;
pub const SSRC_OFFSET: usize = TIMESTAMP_OFFSET + TIMESTAMP_LENGTH;
pub const SSRC_LENGTH: usize = 4;
pub const CSRC_OFFSET: usize = SSRC_OFFSET + SSRC_LENGTH;
pub const CSRC_LENGTH: usize = 4;
28 
29 #[derive(Debug, Eq, PartialEq, Default, Clone)]
30 pub struct Extension {
31     pub id: u8,
32     pub payload: Bytes,
33 }
34 
35 /// Header represents an RTP packet header
36 /// NOTE: PayloadOffset is populated by Marshal/Unmarshal and should not be modified
37 #[derive(Debug, Eq, PartialEq, Default, Clone)]
38 pub struct Header {
39     pub version: u8,
40     pub padding: bool,
41     pub extension: bool,
42     pub marker: bool,
43     pub payload_type: u8,
44     pub sequence_number: u16,
45     pub timestamp: u32,
46     pub ssrc: u32,
47     pub csrc: Vec<u32>,
48     pub extension_profile: u16,
49     pub extensions: Vec<Extension>,
50 }
51 
52 impl Unmarshal for Header {
53     /// Unmarshal parses the passed byte slice and stores the result in the Header this method is called upon
54     fn unmarshal<B>(raw_packet: &mut B) -> Result<Self, util::Error>
55     where
56         Self: Sized,
57         B: Buf,
58     {
59         let raw_packet_len = raw_packet.remaining();
60         if raw_packet_len < HEADER_LENGTH {
61             return Err(Error::ErrHeaderSizeInsufficient.into());
62         }
63         /*
64          *  0                   1                   2                   3
65          *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
66          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
67          * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
68          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
69          * |                           timestamp                           |
70          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
71          * |           synchronization source (SSRC) identifier            |
72          * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
73          * |            contributing source (CSRC) identifiers             |
74          * |                             ....                              |
75          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
76          */
77         let b0 = raw_packet.get_u8();
78         let version = b0 >> VERSION_SHIFT & VERSION_MASK;
79         let padding = (b0 >> PADDING_SHIFT & PADDING_MASK) > 0;
80         let extension = (b0 >> EXTENSION_SHIFT & EXTENSION_MASK) > 0;
81         let cc = (b0 & CC_MASK) as usize;
82 
83         let mut curr_offset = CSRC_OFFSET + (cc * CSRC_LENGTH);
84         if raw_packet_len < curr_offset {
85             return Err(Error::ErrHeaderSizeInsufficient.into());
86         }
87 
88         let b1 = raw_packet.get_u8();
89         let marker = (b1 >> MARKER_SHIFT & MARKER_MASK) > 0;
90         let payload_type = b1 & PT_MASK;
91 
92         let sequence_number = raw_packet.get_u16();
93         let timestamp = raw_packet.get_u32();
94         let ssrc = raw_packet.get_u32();
95 
96         let mut csrc = Vec::with_capacity(cc);
97         for _ in 0..cc {
98             csrc.push(raw_packet.get_u32());
99         }
100 
101         let (extension_profile, extensions) = if extension {
102             let expected = curr_offset + 4;
103             if raw_packet_len < expected {
104                 return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
105             }
106             let extension_profile = raw_packet.get_u16();
107             curr_offset += 2;
108             let extension_length = raw_packet.get_u16() as usize * 4;
109             curr_offset += 2;
110 
111             let expected = curr_offset + extension_length;
112             if raw_packet_len < expected {
113                 return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
114             }
115 
116             let mut extensions = vec![];
117             match extension_profile {
118                 // RFC 8285 RTP One Byte Header Extension
119                 EXTENSION_PROFILE_ONE_BYTE => {
120                     let end = curr_offset + extension_length;
121                     while curr_offset < end {
122                         let b = raw_packet.get_u8();
123                         if b == 0x00 {
124                             // padding
125                             curr_offset += 1;
126                             continue;
127                         }
128 
129                         let extid = b >> 4;
130                         let len = ((b & (0xFF ^ 0xF0)) + 1) as usize;
131                         curr_offset += 1;
132 
133                         if extid == EXTENSION_ID_RESERVED {
134                             break;
135                         }
136 
137                         extensions.push(Extension {
138                             id: extid,
139                             payload: raw_packet.copy_to_bytes(len),
140                         });
141                         curr_offset += len;
142                     }
143                 }
144                 // RFC 8285 RTP Two Byte Header Extension
145                 EXTENSION_PROFILE_TWO_BYTE => {
146                     let end = curr_offset + extension_length;
147                     while curr_offset < end {
148                         let b = raw_packet.get_u8();
149                         if b == 0x00 {
150                             // padding
151                             curr_offset += 1;
152                             continue;
153                         }
154 
155                         let extid = b;
156                         curr_offset += 1;
157 
158                         let len = raw_packet.get_u8() as usize;
159                         curr_offset += 1;
160 
161                         extensions.push(Extension {
162                             id: extid,
163                             payload: raw_packet.copy_to_bytes(len),
164                         });
165                         curr_offset += len;
166                     }
167                 }
168                 // RFC3550 Extension
169                 _ => {
170                     if raw_packet_len < curr_offset + extension_length {
171                         return Err(Error::ErrHeaderSizeInsufficientForExtension.into());
172                     }
173                     extensions.push(Extension {
174                         id: 0,
175                         payload: raw_packet.copy_to_bytes(extension_length),
176                     });
177                 }
178             };
179 
180             (extension_profile, extensions)
181         } else {
182             (0, vec![])
183         };
184 
185         Ok(Header {
186             version,
187             padding,
188             extension,
189             marker,
190             payload_type,
191             sequence_number,
192             timestamp,
193             ssrc,
194             csrc,
195             extension_profile,
196             extensions,
197         })
198     }
199 }
200 
201 impl MarshalSize for Header {
202     /// MarshalSize returns the size of the packet once marshaled.
203     fn marshal_size(&self) -> usize {
204         let mut head_size = 12 + (self.csrc.len() * CSRC_LENGTH);
205         if self.extension {
206             let extension_payload_len = self.get_extension_payload_len();
207             let extension_payload_size = (extension_payload_len + 3) / 4;
208             head_size += 4 + extension_payload_size * 4;
209         }
210         head_size
211     }
212 }
213 
214 impl Marshal for Header {
215     /// Marshal serializes the header and writes to the buffer.
216     fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize, util::Error> {
217         /*
218          *  0                   1                   2                   3
219          *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
220          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
221          * |V=2|P|X|  CC   |M|     PT      |       sequence number         |
222          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
223          * |                           timestamp                           |
224          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
225          * |           synchronization source (SSRC) identifier            |
226          * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
227          * |            contributing source (CSRC) identifiers             |
228          * |                             ....                              |
229          * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
230          */
231         let remaining_before = buf.remaining_mut();
232         if remaining_before < self.marshal_size() {
233             return Err(Error::ErrBufferTooSmall.into());
234         }
235 
236         // The first byte contains the version, padding bit, extension bit, and csrc size
237         let mut b0 = (self.version << VERSION_SHIFT) | self.csrc.len() as u8;
238         if self.padding {
239             b0 |= 1 << PADDING_SHIFT;
240         }
241 
242         if self.extension {
243             b0 |= 1 << EXTENSION_SHIFT;
244         }
245         buf.put_u8(b0);
246 
247         // The second byte contains the marker bit and payload type.
248         let mut b1 = self.payload_type;
249         if self.marker {
250             b1 |= 1 << MARKER_SHIFT;
251         }
252         buf.put_u8(b1);
253 
254         buf.put_u16(self.sequence_number);
255         buf.put_u32(self.timestamp);
256         buf.put_u32(self.ssrc);
257 
258         for csrc in &self.csrc {
259             buf.put_u32(*csrc);
260         }
261 
262         if self.extension {
263             buf.put_u16(self.extension_profile);
264 
265             // calculate extensions size and round to 4 bytes boundaries
266             let extension_payload_len = self.get_extension_payload_len();
267             if self.extension_profile != EXTENSION_PROFILE_ONE_BYTE
268                 && self.extension_profile != EXTENSION_PROFILE_TWO_BYTE
269                 && extension_payload_len % 4 != 0
270             {
271                 //the payload must be in 32-bit words.
272                 return Err(Error::HeaderExtensionPayloadNot32BitWords.into());
273             }
274             let extension_payload_size = (extension_payload_len as u16 + 3) / 4;
275             buf.put_u16(extension_payload_size);
276 
277             match self.extension_profile {
278                 // RFC 8285 RTP One Byte Header Extension
279                 EXTENSION_PROFILE_ONE_BYTE => {
280                     for extension in &self.extensions {
281                         buf.put_u8((extension.id << 4) | (extension.payload.len() as u8 - 1));
282                         buf.put(&*extension.payload);
283                     }
284                 }
285                 // RFC 8285 RTP Two Byte Header Extension
286                 EXTENSION_PROFILE_TWO_BYTE => {
287                     for extension in &self.extensions {
288                         buf.put_u8(extension.id);
289                         buf.put_u8(extension.payload.len() as u8);
290                         buf.put(&*extension.payload);
291                     }
292                 }
293                 // RFC3550 Extension
294                 _ => {
295                     if self.extensions.len() != 1 {
296                         return Err(Error::ErrRfc3550headerIdrange.into());
297                     }
298 
299                     if let Some(extension) = self.extensions.first() {
300                         let ext_len = extension.payload.len();
301                         if ext_len % 4 != 0 {
302                             return Err(Error::HeaderExtensionPayloadNot32BitWords.into());
303                         }
304                         buf.put(&*extension.payload);
305                     }
306                 }
307             };
308 
309             // add padding to reach 4 bytes boundaries
310             for _ in extension_payload_len..extension_payload_size as usize * 4 {
311                 buf.put_u8(0);
312             }
313         }
314 
315         let remaining_after = buf.remaining_mut();
316         Ok(remaining_before - remaining_after)
317     }
318 }
319 
320 impl Header {
321     pub fn get_extension_payload_len(&self) -> usize {
322         let payload_len: usize = self
323             .extensions
324             .iter()
325             .map(|extension| extension.payload.len())
326             .sum();
327 
328         let profile_len = self.extensions.len()
329             * match self.extension_profile {
330                 EXTENSION_PROFILE_ONE_BYTE => 1,
331                 EXTENSION_PROFILE_TWO_BYTE => 2,
332                 _ => 0,
333             };
334 
335         payload_len + profile_len
336     }
337 
338     /// SetExtension sets an RTP header extension
339     pub fn set_extension(&mut self, id: u8, payload: Bytes) -> Result<(), Error> {
340         if self.extension {
341             match self.extension_profile {
342                 EXTENSION_PROFILE_ONE_BYTE => {
343                     if !(1..=14).contains(&id) {
344                         return Err(Error::ErrRfc8285oneByteHeaderIdrange);
345                     }
346                     if payload.len() > 16 {
347                         return Err(Error::ErrRfc8285oneByteHeaderSize);
348                     }
349                 }
350                 EXTENSION_PROFILE_TWO_BYTE => {
351                     if id < 1 {
352                         return Err(Error::ErrRfc8285twoByteHeaderIdrange);
353                     }
354                     if payload.len() > 255 {
355                         return Err(Error::ErrRfc8285twoByteHeaderSize);
356                     }
357                 }
358                 _ => {
359                     if id != 0 {
360                         return Err(Error::ErrRfc3550headerIdrange);
361                     }
362                 }
363             };
364 
365             // Update existing if it exists else add new extension
366             if let Some(extension) = self
367                 .extensions
368                 .iter_mut()
369                 .find(|extension| extension.id == id)
370             {
371                 extension.payload = payload;
372             } else {
373                 self.extensions.push(Extension { id, payload });
374             }
375         } else {
376             // No existing header extensions
377             self.extension = true;
378 
379             self.extension_profile = match payload.len() {
380                 0..=16 => EXTENSION_PROFILE_ONE_BYTE,
381                 17..=255 => EXTENSION_PROFILE_TWO_BYTE,
382                 _ => self.extension_profile,
383             };
384 
385             self.extensions.push(Extension { id, payload });
386         }
387         Ok(())
388     }
389 
390     /// returns an extension id array
391     pub fn get_extension_ids(&self) -> Vec<u8> {
392         if self.extension {
393             self.extensions.iter().map(|e| e.id).collect()
394         } else {
395             vec![]
396         }
397     }
398 
399     /// returns an RTP header extension
400     pub fn get_extension(&self, id: u8) -> Option<Bytes> {
401         if self.extension {
402             self.extensions
403                 .iter()
404                 .find(|extension| extension.id == id)
405                 .map(|extension| extension.payload.clone())
406         } else {
407             None
408         }
409     }
410 
411     /// Removes an RTP Header extension
412     pub fn del_extension(&mut self, id: u8) -> Result<(), Error> {
413         if self.extension {
414             if let Some(index) = self
415                 .extensions
416                 .iter()
417                 .position(|extension| extension.id == id)
418             {
419                 self.extensions.remove(index);
420                 Ok(())
421             } else {
422                 Err(Error::ErrHeaderExtensionNotFound)
423             }
424         } else {
425             Err(Error::ErrHeaderExtensionsNotEnabled)
426         }
427     }
428 }
429