xref: /webrtc/dtls/src/crypto/padding.rs (revision ffe74184)
1 use block_modes::block_padding::{PadError, Padding, UnpadError};
2 
3 pub enum DtlsPadding {}
4 /// Reference: RFC5246, 6.2.3.2
5 impl Padding for DtlsPadding {
pad_block(block: &mut [u8], pos: usize) -> Result<(), PadError>6     fn pad_block(block: &mut [u8], pos: usize) -> Result<(), PadError> {
7         if pos == block.len() {
8             return Err(PadError);
9         }
10 
11         let padding_length = block.len() - pos - 1;
12         if padding_length > 255 {
13             return Err(PadError);
14         }
15 
16         set(&mut block[pos..], padding_length as u8);
17 
18         Ok(())
19     }
20 
unpad(data: &[u8]) -> Result<&[u8], UnpadError>21     fn unpad(data: &[u8]) -> Result<&[u8], UnpadError> {
22         let padding_length = data.last().copied().unwrap_or(1) as usize;
23         if padding_length + 1 > data.len() {
24             return Err(UnpadError);
25         }
26 
27         let padding_begin = data.len() - padding_length - 1;
28 
29         if data[padding_begin..data.len() - 1]
30             .iter()
31             .any(|&byte| byte as usize != padding_length)
32         {
33             return Err(UnpadError);
34         }
35 
36         Ok(&data[0..padding_begin])
37     }
38 }
39 
/// Sets all bytes in `dst` equal to `value`.
///
/// Uses the safe [`slice::fill`], which the optimizer typically lowers to a
/// `memset` for byte slices, so no `unsafe` is needed. An empty `dst` is a
/// no-op.
#[inline]
fn set(dst: &mut [u8], value: u8) {
    dst.fill(value);
}
50 
#[cfg(test)]
pub mod tests {
    use rand::Rng;

    use super::*;

    /// Every padding byte (including the trailing length byte) must equal the
    /// number of padding bytes that precede the length byte, and the data
    /// prefix must be left untouched.
    #[test]
    fn padding_length_is_amount_of_bytes_excluding_the_padding_length_itself(
    ) -> Result<(), PadError> {
        let mut rng = rand::thread_rng();

        for data_len in 0..128 {
            for pad_len in 0..(256 - data_len) {
                let total_len = data_len + pad_len + 1;
                let mut block = vec![0u8; total_len];
                rng.fill(&mut block[..data_len]);
                let expected_data = block[..data_len].to_vec();

                DtlsPadding::pad_block(&mut block, data_len)?;

                let (data, padding) = block.split_at(data_len);
                for &byte in padding {
                    assert_eq!(usize::from(byte), pad_len);
                }
                assert_eq!(data, expected_data.as_slice());
            }
        }

        Ok(())
    }

    /// A block already full of data leaves no room for the length byte.
    #[test]
    fn full_block_is_padding_error() {
        (0..256).for_each(|len| {
            let mut block = vec![0u8; len];
            assert!(DtlsPadding::pad_block(&mut block, len).is_err());
        });
    }

    /// 256 bytes of padding cannot be described by a one-byte length field.
    #[test]
    fn padding_length_bigger_than_255_is_a_pad_error() {
        const OVERSIZED_PADDING: usize = 256;

        for data_len in 0..128 {
            let mut block = vec![0u8; data_len + OVERSIZED_PADDING + 1];
            assert!(DtlsPadding::pad_block(&mut block, data_len).is_err());
        }
    }

    /// There is no length byte to read from an empty input.
    #[test]
    fn empty_block_is_unpadding_error() {
        assert!(DtlsPadding::unpad(&[]).is_err());
    }

    /// A lone length byte of 1 claims one padding byte that is not there.
    #[test]
    fn padding_too_big_for_block_is_unpadding_error() {
        assert!(DtlsPadding::unpad(&[1]).is_err());
    }

    /// Corrupting any single padding byte must make unpadding fail.
    #[test]
    fn one_of_the_padding_bytes_with_value_different_than_padding_length_is_unpadding_error() {
        for pad_len in 0..16 {
            for corrupt_at in 0..pad_len {
                let mut block = vec![0u8; pad_len + 1];
                DtlsPadding::pad_block(&mut block, 0).unwrap();

                // Sanity check: an intact, all-padding block unpads to nothing.
                assert_eq!(DtlsPadding::unpad(&block).ok(), Some(&[][..]));

                block[corrupt_at] = (pad_len - 1) as u8;
                assert!(DtlsPadding::unpad(&block).is_err());
            }
        }
    }
}
124