xref: /webrtc/ice/src/udp_mux/socket_addr_ext.rs (revision ffe74184)
1 use std::array::TryFromSliceError;
2 use std::convert::TryInto;
3 use std::net::SocketAddr;
4 
5 use util::Error;
6 
7 pub(super) trait SocketAddrExt {
8     ///Encode a representation of `self` into the buffer and return the length of this encoded
9     ///version.
10     ///
11     /// The buffer needs to be at least 27 bytes in length.
encode(&self, buffer: &mut [u8]) -> Result<usize, Error>12     fn encode(&self, buffer: &mut [u8]) -> Result<usize, Error>;
13 
14     /// Decode a `SocketAddr` from a buffer. The encoding should have previously been done with
15     /// [`SocketAddrExt::encode`].
decode(buffer: &[u8]) -> Result<SocketAddr, Error>16     fn decode(buffer: &[u8]) -> Result<SocketAddr, Error>;
17 }
18 
/// Leading tag byte of an encoded IPv4 socket address.
const IPV4_MARKER: u8 = 4;
/// Encoded IPv4 length: marker (1) + octets (4) + port (2).
const IPV4_ADDRESS_SIZE: usize = 1 + 4 + 2;
/// Leading tag byte of an encoded IPv6 socket address.
const IPV6_MARKER: u8 = 6;
/// Encoded IPv6 length: marker (1) + octets (16) + port (2) + flowinfo (4) + scope_id (4).
const IPV6_ADDRESS_SIZE: usize = 1 + 16 + 2 + 4 + 4;
23 
24 pub(super) const MAX_ADDR_SIZE: usize = IPV6_ADDRESS_SIZE;
25 
26 impl SocketAddrExt for SocketAddr {
encode(&self, buffer: &mut [u8]) -> Result<usize, Error>27     fn encode(&self, buffer: &mut [u8]) -> Result<usize, Error> {
28         use std::net::SocketAddr::{V4, V6};
29 
30         if buffer.len() < MAX_ADDR_SIZE {
31             return Err(Error::ErrBufferShort);
32         }
33 
34         match self {
35             V4(addr) => {
36                 let marker = IPV4_MARKER;
37                 let ip: [u8; 4] = addr.ip().octets();
38                 let port: u16 = addr.port();
39 
40                 buffer[0] = marker;
41                 buffer[1..5].copy_from_slice(&ip);
42                 buffer[5..7].copy_from_slice(&port.to_le_bytes());
43 
44                 Ok(7)
45             }
46             V6(addr) => {
47                 let marker = IPV6_MARKER;
48                 let ip: [u8; 16] = addr.ip().octets();
49                 let port: u16 = addr.port();
50                 let flowinfo = addr.flowinfo();
51                 let scope_id = addr.scope_id();
52 
53                 buffer[0] = marker;
54                 buffer[1..17].copy_from_slice(&ip);
55                 buffer[17..19].copy_from_slice(&port.to_le_bytes());
56                 buffer[19..23].copy_from_slice(&flowinfo.to_le_bytes());
57                 buffer[23..27].copy_from_slice(&scope_id.to_le_bytes());
58 
59                 Ok(MAX_ADDR_SIZE)
60             }
61         }
62     }
63 
decode(buffer: &[u8]) -> Result<SocketAddr, Error>64     fn decode(buffer: &[u8]) -> Result<SocketAddr, Error> {
65         use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
66 
67         match buffer[0] {
68             IPV4_MARKER => {
69                 if buffer.len() < IPV4_ADDRESS_SIZE {
70                     return Err(Error::ErrBufferShort);
71                 }
72 
73                 let ip_parts = &buffer[1..5];
74                 let port = match &buffer[5..7].try_into() {
75                     Err(_) => return Err(Error::ErrFailedToParseIpaddr),
76                     Ok(input) => u16::from_le_bytes(*input),
77                 };
78 
79                 let ip = Ipv4Addr::new(ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]);
80 
81                 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
82             }
83             IPV6_MARKER => {
84                 if buffer.len() < IPV6_ADDRESS_SIZE {
85                     return Err(Error::ErrBufferShort);
86                 }
87 
88                 // Just to help the type system infer correctly
89                 fn helper(b: &[u8]) -> Result<&[u8; 16], TryFromSliceError> {
90                     b.try_into()
91                 }
92 
93                 let ip = match helper(&buffer[1..17]) {
94                     Err(_) => return Err(Error::ErrFailedToParseIpaddr),
95                     Ok(input) => Ipv6Addr::from(*input),
96                 };
97                 let port = match &buffer[17..19].try_into() {
98                     Err(_) => return Err(Error::ErrFailedToParseIpaddr),
99                     Ok(input) => u16::from_le_bytes(*input),
100                 };
101 
102                 let flowinfo = match &buffer[19..23].try_into() {
103                     Err(_) => return Err(Error::ErrFailedToParseIpaddr),
104                     Ok(input) => u32::from_le_bytes(*input),
105                 };
106 
107                 let scope_id = match &buffer[23..27].try_into() {
108                     Err(_) => return Err(Error::ErrFailedToParseIpaddr),
109                     Ok(input) => u32::from_le_bytes(*input),
110                 };
111 
112                 Ok(SocketAddr::V6(SocketAddrV6::new(
113                     ip, port, flowinfo, scope_id,
114                 )))
115             }
116             _ => Err(Error::ErrFailedToParseIpaddr),
117         }
118     }
119 }
120 
#[cfg(test)]
mod test {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// IPv4 fixture shared by the encode/decode tests.
    fn sample_ipv4() -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234))
    }

    /// IPv6 fixture with non-zero flowinfo and scope_id so their byte order is observable.
    fn sample_ipv6() -> SocketAddr {
        SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::from([
                92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123,
            ]),
            0x1234,
            0x12345678,
            0x87654321,
        ))
    }

    #[test]
    fn test_ipv4() {
        let ip = sample_ipv4();

        let mut buffer = [0_u8; MAX_ADDR_SIZE];
        let encoded_len = ip.encode(&mut buffer);

        assert_eq!(encoded_len, Ok(IPV4_ADDRESS_SIZE));
        assert_eq!(
            &buffer[..IPV4_ADDRESS_SIZE],
            &[IPV4_MARKER, 56, 128, 35, 5, 0x34, 0x12][..]
        );

        assert_eq!(SocketAddr::decode(&buffer), Ok(ip));
    }

    #[test]
    fn test_ipv6() {
        let ip = sample_ipv6();

        let mut buffer = [0_u8; MAX_ADDR_SIZE];
        let encoded_len = ip.encode(&mut buffer);

        #[rustfmt::skip]
        let expected: [u8; IPV6_ADDRESS_SIZE] = [
            IPV6_MARKER,
            // ipv6 octets
            92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123,
            // LE port
            0x34, 0x12,
            // LE flowinfo
            0x78, 0x56, 0x34, 0x12,
            // LE scope_id
            0x21, 0x43, 0x65, 0x87,
        ];

        assert_eq!(encoded_len, Ok(IPV6_ADDRESS_SIZE));
        assert_eq!(&buffer[..IPV6_ADDRESS_SIZE], &expected[..]);

        assert_eq!(SocketAddr::decode(&buffer), Ok(ip));
    }

    #[test]
    fn test_encode_ipv4_with_short_buffer() {
        let mut buffer = [0u8; IPV4_ADDRESS_SIZE - 1];

        assert_eq!(sample_ipv4().encode(&mut buffer), Err(Error::ErrBufferShort));
    }

    #[test]
    fn test_encode_ipv4_requires_max_addr_size() {
        // The documented contract asks for MAX_ADDR_SIZE bytes even for IPv4.
        let mut buffer = [0u8; IPV4_ADDRESS_SIZE];

        assert_eq!(sample_ipv4().encode(&mut buffer), Err(Error::ErrBufferShort));
    }

    #[test]
    fn test_encode_ipv6_with_short_buffer() {
        let mut buffer = [0u8; MAX_ADDR_SIZE - 1];

        assert_eq!(sample_ipv6().encode(&mut buffer), Err(Error::ErrBufferShort));
    }

    #[test]
    fn test_decode_ipv4_with_short_buffer() {
        let buffer = [IPV4_MARKER, 0];

        assert_eq!(SocketAddr::decode(&buffer), Err(Error::ErrBufferShort));
    }

    #[test]
    fn test_decode_ipv6_with_short_buffer() {
        let buffer = [IPV6_MARKER, 0];

        assert_eq!(SocketAddr::decode(&buffer), Err(Error::ErrBufferShort));
    }

    #[test]
    fn test_decode_unknown_marker() {
        let buffer = [0u8; MAX_ADDR_SIZE];

        assert_eq!(
            SocketAddr::decode(&buffer),
            Err(Error::ErrFailedToParseIpaddr)
        );
    }
}
246