1 use std::array::TryFromSliceError; 2 use std::convert::TryInto; 3 use std::net::SocketAddr; 4 5 use util::Error; 6 7 pub(super) trait SocketAddrExt { 8 ///Encode a representation of `self` into the buffer and return the length of this encoded 9 ///version. 10 /// 11 /// The buffer needs to be at least 27 bytes in length. encode(&self, buffer: &mut [u8]) -> Result<usize, Error>12 fn encode(&self, buffer: &mut [u8]) -> Result<usize, Error>; 13 14 /// Decode a `SocketAddr` from a buffer. The encoding should have previously been done with 15 /// [`SocketAddrExt::encode`]. decode(buffer: &[u8]) -> Result<SocketAddr, Error>16 fn decode(buffer: &[u8]) -> Result<SocketAddr, Error>; 17 } 18 19 const IPV4_MARKER: u8 = 4; 20 const IPV4_ADDRESS_SIZE: usize = 7; 21 const IPV6_MARKER: u8 = 6; 22 const IPV6_ADDRESS_SIZE: usize = 27; 23 24 pub(super) const MAX_ADDR_SIZE: usize = IPV6_ADDRESS_SIZE; 25 26 impl SocketAddrExt for SocketAddr { encode(&self, buffer: &mut [u8]) -> Result<usize, Error>27 fn encode(&self, buffer: &mut [u8]) -> Result<usize, Error> { 28 use std::net::SocketAddr::{V4, V6}; 29 30 if buffer.len() < MAX_ADDR_SIZE { 31 return Err(Error::ErrBufferShort); 32 } 33 34 match self { 35 V4(addr) => { 36 let marker = IPV4_MARKER; 37 let ip: [u8; 4] = addr.ip().octets(); 38 let port: u16 = addr.port(); 39 40 buffer[0] = marker; 41 buffer[1..5].copy_from_slice(&ip); 42 buffer[5..7].copy_from_slice(&port.to_le_bytes()); 43 44 Ok(7) 45 } 46 V6(addr) => { 47 let marker = IPV6_MARKER; 48 let ip: [u8; 16] = addr.ip().octets(); 49 let port: u16 = addr.port(); 50 let flowinfo = addr.flowinfo(); 51 let scope_id = addr.scope_id(); 52 53 buffer[0] = marker; 54 buffer[1..17].copy_from_slice(&ip); 55 buffer[17..19].copy_from_slice(&port.to_le_bytes()); 56 buffer[19..23].copy_from_slice(&flowinfo.to_le_bytes()); 57 buffer[23..27].copy_from_slice(&scope_id.to_le_bytes()); 58 59 Ok(MAX_ADDR_SIZE) 60 } 61 } 62 } 63 decode(buffer: &[u8]) -> Result<SocketAddr, Error>64 fn decode(buffer: &[u8]) -> Result<SocketAddr, Error> { 65 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; 66 67 match buffer[0] { 68 IPV4_MARKER => { 69 if buffer.len() < IPV4_ADDRESS_SIZE { 70 return Err(Error::ErrBufferShort); 71 } 72 73 let ip_parts = &buffer[1..5]; 74 let port = match &buffer[5..7].try_into() { 75 Err(_) => return Err(Error::ErrFailedToParseIpaddr), 76 Ok(input) => u16::from_le_bytes(*input), 77 }; 78 79 let ip = Ipv4Addr::new(ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]); 80 81 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) 82 } 83 IPV6_MARKER => { 84 if buffer.len() < IPV6_ADDRESS_SIZE { 85 return Err(Error::ErrBufferShort); 86 } 87 88 // Just to help the type system infer correctly 89 fn helper(b: &[u8]) -> Result<&[u8; 16], TryFromSliceError> { 90 b.try_into() 91 } 92 93 let ip = match helper(&buffer[1..17]) { 94 Err(_) => return Err(Error::ErrFailedToParseIpaddr), 95 Ok(input) => Ipv6Addr::from(*input), 96 }; 97 let port = match &buffer[17..19].try_into() { 98 Err(_) => return Err(Error::ErrFailedToParseIpaddr), 99 Ok(input) => u16::from_le_bytes(*input), 100 }; 101 102 let flowinfo = match &buffer[19..23].try_into() { 103 Err(_) => return Err(Error::ErrFailedToParseIpaddr), 104 Ok(input) => u32::from_le_bytes(*input), 105 }; 106 107 let scope_id = match &buffer[23..27].try_into() { 108 Err(_) => return Err(Error::ErrFailedToParseIpaddr), 109 Ok(input) => u32::from_le_bytes(*input), 110 }; 111 112 Ok(SocketAddr::V6(SocketAddrV6::new( 113 ip, port, flowinfo, scope_id, 114 ))) 115 } 116 _ => Err(Error::ErrFailedToParseIpaddr), 117 } 118 } 119 } 120 121 #[cfg(test)] 122 mod test { 123 use super::*; 124 use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; 125 126 #[test] test_ipv4()127 fn test_ipv4() { 128 let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); 129 130 let mut buffer = [0_u8; MAX_ADDR_SIZE]; 131 let encoded_len = ip.encode(&mut buffer); 132 133 assert_eq!(encoded_len, Ok(7)); 134 assert_eq!( 135 &buffer[0..7], 136 &[IPV4_MARKER, 56, 128, 35, 5, 0x34, 0x12][..] 137 ); 138 139 let decoded = SocketAddr::decode(&buffer); 140 141 assert_eq!(decoded, Ok(ip)); 142 } 143 144 #[test] test_ipv6()145 fn test_ipv6() { 146 let ip = SocketAddr::V6(SocketAddrV6::new( 147 Ipv6Addr::from([ 148 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, 149 ]), 150 0x1234, 151 0x12345678, 152 0x87654321, 153 )); 154 155 let mut buffer = [0_u8; MAX_ADDR_SIZE]; 156 let encoded_len = ip.encode(&mut buffer); 157 158 assert_eq!(encoded_len, Ok(27)); 159 assert_eq!( 160 &buffer[0..27], 161 &[ 162 IPV6_MARKER, // marker 163 // Start of ipv6 address 164 92, 165 114, 166 235, 167 3, 168 244, 169 64, 170 38, 171 111, 172 20, 173 100, 174 199, 175 241, 176 19, 177 174, 178 220, 179 123, 180 // LE port 181 0x34, 182 0x12, 183 // LE flowinfo 184 0x78, 185 0x56, 186 0x34, 187 0x12, 188 // LE scope_id 189 0x21, 190 0x43, 191 0x65, 192 0x87, 193 ][..] 194 ); 195 196 let decoded = SocketAddr::decode(&buffer); 197 198 assert_eq!(decoded, Ok(ip)); 199 } 200 201 #[test] test_encode_ipv4_with_short_buffer()202 fn test_encode_ipv4_with_short_buffer() { 203 let mut buffer = vec![0u8; IPV4_ADDRESS_SIZE - 1]; 204 let ip = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from([56, 128, 35, 5]), 0x1234)); 205 206 let result = ip.encode(&mut buffer); 207 208 assert_eq!(result, Err(Error::ErrBufferShort)); 209 } 210 211 #[test] test_encode_ipv6_with_short_buffer()212 fn test_encode_ipv6_with_short_buffer() { 213 let mut buffer = vec![0u8; MAX_ADDR_SIZE - 1]; 214 let ip = SocketAddr::V6(SocketAddrV6::new( 215 Ipv6Addr::from([ 216 92, 114, 235, 3, 244, 64, 38, 111, 20, 100, 199, 241, 19, 174, 220, 123, 217 ]), 218 0x1234, 219 0x12345678, 220 0x87654321, 221 )); 222 223 let result = ip.encode(&mut buffer); 224 225 assert_eq!(result, Err(Error::ErrBufferShort)); 226 } 227 228 #[test] test_decode_ipv4_with_short_buffer()229 fn test_decode_ipv4_with_short_buffer() { 230 let buffer = vec![IPV4_MARKER, 0]; 231 232 let result = SocketAddr::decode(&buffer); 233 234 assert_eq!(result, Err(Error::ErrBufferShort)); 235 } 236 237 #[test] test_decode_ipv6_with_short_buffer()238 fn test_decode_ipv6_with_short_buffer() { 239 let buffer = vec![IPV6_MARKER, 0]; 240 241 let result = SocketAddr::decode(&buffer); 242 243 assert_eq!(result, Err(Error::ErrBufferShort)); 244 } 245 } 246