1 use super::*; 2 use crate::error::Error; 3 4 // The first byte in a `Message` that specifies its type: 5 pub(crate) const MESSAGE_TYPE_ACK: u8 = 0x02; 6 pub(crate) const MESSAGE_TYPE_OPEN: u8 = 0x03; 7 pub(crate) const MESSAGE_TYPE_LEN: usize = 1; 8 9 type Result<T> = std::result::Result<T, util::Error>; 10 11 // A parsed DataChannel message 12 #[derive(Eq, PartialEq, Copy, Clone, Debug)] 13 pub enum MessageType { 14 DataChannelAck, 15 DataChannelOpen, 16 } 17 18 impl MarshalSize for MessageType { marshal_size(&self) -> usize19 fn marshal_size(&self) -> usize { 20 MESSAGE_TYPE_LEN 21 } 22 } 23 24 impl Marshal for MessageType { marshal_to(&self, mut buf: &mut [u8]) -> Result<usize>25 fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> { 26 let b = match self { 27 MessageType::DataChannelAck => MESSAGE_TYPE_ACK, 28 MessageType::DataChannelOpen => MESSAGE_TYPE_OPEN, 29 }; 30 31 buf.put_u8(b); 32 33 Ok(1) 34 } 35 } 36 37 impl Unmarshal for MessageType { unmarshal<B>(buf: &mut B) -> Result<Self> where B: Buf,38 fn unmarshal<B>(buf: &mut B) -> Result<Self> 39 where 40 B: Buf, 41 { 42 let required_len = MESSAGE_TYPE_LEN; 43 if buf.remaining() < required_len { 44 return Err(Error::UnexpectedEndOfBuffer { 45 expected: required_len, 46 actual: buf.remaining(), 47 } 48 .into()); 49 } 50 51 let b = buf.get_u8(); 52 53 match b { 54 MESSAGE_TYPE_ACK => Ok(Self::DataChannelAck), 55 MESSAGE_TYPE_OPEN => Ok(Self::DataChannelOpen), 56 _ => Err(Error::InvalidMessageType(b).into()), 57 } 58 } 59 } 60 61 #[cfg(test)] 62 mod tests { 63 use super::*; 64 use bytes::{Bytes, BytesMut}; 65 66 #[test] test_message_type_unmarshal_open_success() -> Result<()>67 fn test_message_type_unmarshal_open_success() -> Result<()> { 68 let mut bytes = Bytes::from_static(&[0x03]); 69 let msg_type = MessageType::unmarshal(&mut bytes)?; 70 71 assert_eq!(msg_type, MessageType::DataChannelOpen); 72 73 Ok(()) 74 } 75 76 #[test] test_message_type_unmarshal_ack_success() -> Result<()>77 fn test_message_type_unmarshal_ack_success() -> Result<()> { 78 let mut bytes = Bytes::from_static(&[0x02]); 79 let msg_type = MessageType::unmarshal(&mut bytes)?; 80 81 assert_eq!(msg_type, MessageType::DataChannelAck); 82 Ok(()) 83 } 84 85 #[test] test_message_type_unmarshal_invalid() -> Result<()>86 fn test_message_type_unmarshal_invalid() -> Result<()> { 87 let mut bytes = Bytes::from_static(&[0x01]); 88 match MessageType::unmarshal(&mut bytes) { 89 Ok(_) => panic!("expected Error, but got Ok"), 90 Err(err) => { 91 if let Some(&Error::InvalidMessageType(0x01)) = err.downcast_ref::<Error>() { 92 return Ok(()); 93 } 94 panic!( 95 "unexpected err {:?}, want {:?}", 96 err, 97 Error::InvalidMessageType(0x01) 98 ); 99 } 100 } 101 } 102 103 #[test] test_message_type_marshal_size() -> Result<()>104 fn test_message_type_marshal_size() -> Result<()> { 105 let ack = MessageType::DataChannelAck; 106 let marshal_size = ack.marshal_size(); 107 108 assert_eq!(marshal_size, MESSAGE_TYPE_LEN); 109 Ok(()) 110 } 111 112 #[test] test_message_type_marshal() -> Result<()>113 fn test_message_type_marshal() -> Result<()> { 114 let mut buf = BytesMut::with_capacity(MESSAGE_TYPE_LEN); 115 buf.resize(MESSAGE_TYPE_LEN, 0u8); 116 let msg_type = MessageType::DataChannelAck; 117 let n = msg_type.marshal_to(&mut buf)?; 118 let bytes = buf.freeze(); 119 120 assert_eq!(n, MESSAGE_TYPE_LEN); 121 assert_eq!(&bytes[..], &[0x02]); 122 Ok(()) 123 } 124 } 125