xref: /webrtc/data/src/message/message_type.rs (revision ffe74184)
1 use super::*;
2 use crate::error::Error;
3 
// Identifiers carried in the first byte of every DataChannel `Message`.

/// Type byte identifying a DATA_CHANNEL_ACK message.
pub(crate) const MESSAGE_TYPE_ACK: u8 = 0x02;
/// Type byte identifying a DATA_CHANNEL_OPEN message.
pub(crate) const MESSAGE_TYPE_OPEN: u8 = 0x03;
/// Number of bytes the message-type field occupies (a single `u8`).
pub(crate) const MESSAGE_TYPE_LEN: usize = std::mem::size_of::<u8>();
8 
9 type Result<T> = std::result::Result<T, util::Error>;
10 
/// Kind of a DataChannel control message, identified on the wire by its
/// leading type byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// DATA_CHANNEL_ACK, encoded as the `MESSAGE_TYPE_ACK` byte.
    DataChannelAck,
    /// DATA_CHANNEL_OPEN, encoded as the `MESSAGE_TYPE_OPEN` byte.
    DataChannelOpen,
}
17 
18 impl MarshalSize for MessageType {
marshal_size(&self) -> usize19     fn marshal_size(&self) -> usize {
20         MESSAGE_TYPE_LEN
21     }
22 }
23 
24 impl Marshal for MessageType {
marshal_to(&self, mut buf: &mut [u8]) -> Result<usize>25     fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
26         let b = match self {
27             MessageType::DataChannelAck => MESSAGE_TYPE_ACK,
28             MessageType::DataChannelOpen => MESSAGE_TYPE_OPEN,
29         };
30 
31         buf.put_u8(b);
32 
33         Ok(1)
34     }
35 }
36 
37 impl Unmarshal for MessageType {
unmarshal<B>(buf: &mut B) -> Result<Self> where B: Buf,38     fn unmarshal<B>(buf: &mut B) -> Result<Self>
39     where
40         B: Buf,
41     {
42         let required_len = MESSAGE_TYPE_LEN;
43         if buf.remaining() < required_len {
44             return Err(Error::UnexpectedEndOfBuffer {
45                 expected: required_len,
46                 actual: buf.remaining(),
47             }
48             .into());
49         }
50 
51         let b = buf.get_u8();
52 
53         match b {
54             MESSAGE_TYPE_ACK => Ok(Self::DataChannelAck),
55             MESSAGE_TYPE_OPEN => Ok(Self::DataChannelOpen),
56             _ => Err(Error::InvalidMessageType(b).into()),
57         }
58     }
59 }
60 
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{Bytes, BytesMut};

    #[test]
    fn test_message_type_unmarshal_open_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x03]);
        let msg_type = MessageType::unmarshal(&mut bytes)?;

        assert_eq!(msg_type, MessageType::DataChannelOpen);

        Ok(())
    }

    #[test]
    fn test_message_type_unmarshal_ack_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x02]);
        let msg_type = MessageType::unmarshal(&mut bytes)?;

        assert_eq!(msg_type, MessageType::DataChannelAck);
        Ok(())
    }

    #[test]
    fn test_message_type_unmarshal_invalid() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x01]);
        let err = MessageType::unmarshal(&mut bytes).expect_err("expected Error, but got Ok");

        assert!(
            matches!(
                err.downcast_ref::<Error>(),
                Some(&Error::InvalidMessageType(0x01))
            ),
            "unexpected err {:?}, want {:?}",
            err,
            Error::InvalidMessageType(0x01)
        );
        Ok(())
    }

    /// An empty buffer must be reported as `UnexpectedEndOfBuffer`, not panic.
    #[test]
    fn test_message_type_unmarshal_empty_buffer() {
        let mut bytes = Bytes::new();
        let err = MessageType::unmarshal(&mut bytes).expect_err("expected Error, but got Ok");

        match err.downcast_ref::<Error>() {
            Some(Error::UnexpectedEndOfBuffer { expected, actual }) => {
                assert_eq!(*expected, MESSAGE_TYPE_LEN);
                assert_eq!(*actual, 0);
            }
            _ => panic!("unexpected err {:?}, want UnexpectedEndOfBuffer", err),
        }
    }

    /// Unmarshalling reads exactly the type byte and leaves the rest in place.
    #[test]
    fn test_message_type_unmarshal_consumes_one_byte() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x02, 0xff]);
        let msg_type = MessageType::unmarshal(&mut bytes)?;

        assert_eq!(msg_type, MessageType::DataChannelAck);
        assert_eq!(&bytes[..], &[0xff]);
        Ok(())
    }

    #[test]
    fn test_message_type_marshal_size() -> Result<()> {
        let ack = MessageType::DataChannelAck;
        let marshal_size = ack.marshal_size();

        assert_eq!(marshal_size, MESSAGE_TYPE_LEN);
        assert_eq!(MessageType::DataChannelOpen.marshal_size(), MESSAGE_TYPE_LEN);
        Ok(())
    }

    #[test]
    fn test_message_type_marshal() -> Result<()> {
        let mut buf = BytesMut::with_capacity(MESSAGE_TYPE_LEN);
        buf.resize(MESSAGE_TYPE_LEN, 0u8);
        let msg_type = MessageType::DataChannelAck;
        let n = msg_type.marshal_to(&mut buf)?;
        let bytes = buf.freeze();

        assert_eq!(n, MESSAGE_TYPE_LEN);
        assert_eq!(&bytes[..], &[0x02]);
        Ok(())
    }

    #[test]
    fn test_message_type_marshal_open() -> Result<()> {
        let mut buf = BytesMut::with_capacity(MESSAGE_TYPE_LEN);
        buf.resize(MESSAGE_TYPE_LEN, 0u8);
        let n = MessageType::DataChannelOpen.marshal_to(&mut buf)?;

        assert_eq!(n, MESSAGE_TYPE_LEN);
        assert_eq!(&buf[..], &[0x03]);
        Ok(())
    }

    /// Every variant survives a marshal -> unmarshal round trip.
    #[test]
    fn test_message_type_round_trip() -> Result<()> {
        for &msg_type in &[MessageType::DataChannelAck, MessageType::DataChannelOpen] {
            let mut buf = [0u8; MESSAGE_TYPE_LEN];
            let n = msg_type.marshal_to(&mut buf)?;
            assert_eq!(n, MESSAGE_TYPE_LEN);

            let mut bytes = Bytes::copy_from_slice(&buf);
            assert_eq!(MessageType::unmarshal(&mut bytes)?, msg_type);
            assert!(bytes.is_empty());
        }
        Ok(())
    }
}
125