1 use {
2     super::{
3         define, define::ServerHandshakeState, digest::DigestProcessor, errors::HandshakeError,
4         handshake_trait::THandshakeServer, utils,
5     },
6     byteorder::BigEndian,
7     bytes::BytesMut,
8     bytesio::{
9         bytes_reader::BytesReader, bytes_writer::AsyncBytesWriter, bytes_writer::BytesWriter,
10         bytesio::TNetIO,
11     },
12     std::sync::Arc,
13     tokio::sync::Mutex,
14 };
15 
16 pub struct SimpleHandshakeServer {
17     pub reader: BytesReader,
18     pub writer: AsyncBytesWriter,
19     pub state: ServerHandshakeState,
20 
21     c1_bytes: BytesMut,
22     c1_timestamp: u32,
23 }
24 
25 pub struct ComplexHandshakeServer {
26     pub reader: BytesReader,
27     pub writer: AsyncBytesWriter,
28     pub state: ServerHandshakeState,
29 
30     c1_digest: BytesMut,
31     c1_timestamp: u32,
32 }
33 
34 impl SimpleHandshakeServer {
new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self35     pub fn new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self {
36         Self {
37             reader: BytesReader::new(BytesMut::new()),
38             writer: AsyncBytesWriter::new(io),
39             state: ServerHandshakeState::ReadC0C1,
40 
41             c1_bytes: BytesMut::new(),
42             c1_timestamp: 0,
43         }
44     }
extend_data(&mut self, data: &[u8])45     pub fn extend_data(&mut self, data: &[u8]) {
46         self.reader.extend_from_slice(data);
47     }
48 
handshake(&mut self) -> Result<(), HandshakeError>49     pub async fn handshake(&mut self) -> Result<(), HandshakeError> {
50         loop {
51             match self.state {
52                 ServerHandshakeState::ReadC0C1 => {
53                     log::info!("[ S<-C ] [simple handshake] read C0C1");
54                     self.read_c0()?;
55                     self.read_c1()?;
56                     self.state = ServerHandshakeState::WriteS0S1S2;
57                 }
58 
59                 ServerHandshakeState::WriteS0S1S2 => {
60                     log::info!("[ S->C ] [simple handshake] write S0S1S2");
61                     self.write_s0()?;
62                     self.write_s1()?;
63                     self.write_s2()?;
64                     self.writer.flush().await?;
65                     self.state = ServerHandshakeState::ReadC2;
66                     break;
67                 }
68 
69                 ServerHandshakeState::ReadC2 => {
70                     log::info!("[ S<-C ] [simple handshake] read C2");
71                     self.read_c2()?;
72                     self.state = ServerHandshakeState::Finish;
73                 }
74 
75                 ServerHandshakeState::Finish => {
76                     log::info!("simple handshake successfully..");
77                     break;
78                 }
79             }
80         }
81 
82         Ok(())
83     }
84 }
85 
86 impl ComplexHandshakeServer {
new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self87     pub fn new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self {
88         Self {
89             reader: BytesReader::new(BytesMut::new()),
90             writer: AsyncBytesWriter::new(io),
91             state: ServerHandshakeState::ReadC0C1,
92 
93             c1_digest: BytesMut::new(),
94             c1_timestamp: 0,
95         }
96     }
97 
extend_data(&mut self, data: &[u8])98     pub fn extend_data(&mut self, data: &[u8]) {
99         self.reader.extend_from_slice(data);
100     }
101 
handshake(&mut self) -> Result<(), HandshakeError>102     pub async fn handshake(&mut self) -> Result<(), HandshakeError> {
103         loop {
104             match self.state {
105                 ServerHandshakeState::ReadC0C1 => {
106                     log::info!("[ S<-C ] [complex handshake] read C0C1");
107                     self.read_c0()?;
108                     self.read_c1()?;
109                     self.state = ServerHandshakeState::WriteS0S1S2;
110                 }
111 
112                 ServerHandshakeState::WriteS0S1S2 => {
113                     log::info!("[ S->C ] [complex handshake] write S0S1S2");
114                     self.write_s0()?;
115                     self.write_s1()?;
116                     self.write_s2()?;
117                     self.writer.flush().await?;
118                     log::info!("[ S->C ] [complex handshake] write S0S1S2 finish");
119                     self.state = ServerHandshakeState::ReadC2;
120                     break;
121                 }
122 
123                 ServerHandshakeState::ReadC2 => {
124                     log::info!("[ S<-C ] [complex handshake] read C2");
125                     self.read_c2()?;
126                     self.state = ServerHandshakeState::Finish;
127                 }
128 
129                 ServerHandshakeState::Finish => {
130                     log::info!("complex handshake successfully..");
131                     break;
132                 }
133             }
134         }
135 
136         Ok(())
137     }
138 }
139 
140 impl THandshakeServer for SimpleHandshakeServer {
read_c0(&mut self) -> Result<(), HandshakeError>141     fn read_c0(&mut self) -> Result<(), HandshakeError> {
142         self.reader.read_u8()?;
143         Ok(())
144     }
145 
read_c1(&mut self) -> Result<(), HandshakeError>146     fn read_c1(&mut self) -> Result<(), HandshakeError> {
147         let c1_bytes = self.reader.read_bytes(define::RTMP_HANDSHAKE_SIZE)?;
148         self.c1_bytes = c1_bytes.clone();
149 
150         let mut reader = BytesReader::new(c1_bytes);
151         self.c1_timestamp = reader.read_u32::<BigEndian>()?;
152 
153         Ok(())
154     }
155 
read_c2(&mut self) -> Result<(), HandshakeError>156     fn read_c2(&mut self) -> Result<(), HandshakeError> {
157         self.reader.read_bytes(define::RTMP_HANDSHAKE_SIZE)?;
158         Ok(())
159     }
160 
write_s0(&mut self) -> Result<(), HandshakeError>161     fn write_s0(&mut self) -> Result<(), HandshakeError> {
162         self.writer.write_u8(define::RTMP_VERSION as u8)?;
163         Ok(())
164     }
165 
write_s1(&mut self) -> Result<(), HandshakeError>166     fn write_s1(&mut self) -> Result<(), HandshakeError> {
167         self.writer.write_u32::<BigEndian>(utils::current_time())?;
168 
169         let timestamp = self.c1_timestamp;
170         self.writer.write_u32::<BigEndian>(timestamp)?;
171 
172         self.writer
173             .write_random_bytes(define::RTMP_HANDSHAKE_SIZE as u32 - 8)?;
174         Ok(())
175     }
176 
write_s2(&mut self) -> Result<(), HandshakeError>177     fn write_s2(&mut self) -> Result<(), HandshakeError> {
178         let data = self.c1_bytes.clone();
179         self.writer.write(&data[..])?;
180         Ok(())
181     }
182 }
183 
184 impl THandshakeServer for ComplexHandshakeServer {
read_c0(&mut self) -> Result<(), HandshakeError>185     fn read_c0(&mut self) -> Result<(), HandshakeError> {
186         self.reader.read_u8()?;
187         Ok(())
188     }
189 
read_c1(&mut self) -> Result<(), HandshakeError>190     fn read_c1(&mut self) -> Result<(), HandshakeError> {
191         let c1_bytes = self.reader.read_bytes(define::RTMP_HANDSHAKE_SIZE)?;
192 
193         /*read the timestamp*/
194         self.c1_timestamp = BytesReader::new(c1_bytes.clone()).read_u32::<BigEndian>()?;
195 
196         /*read the digest and save*/
197         let mut key = BytesMut::new();
198         key.extend_from_slice(define::RTMP_CLIENT_KEY_FIRST_HALF.as_bytes());
199 
200         let mut digest_processor = DigestProcessor::new(c1_bytes, key);
201         let (digest_content, _) = digest_processor.read_digest()?;
202 
203         self.c1_digest = digest_content;
204 
205         Ok(())
206     }
207 
read_c2(&mut self) -> Result<(), HandshakeError>208     fn read_c2(&mut self) -> Result<(), HandshakeError> {
209         self.reader.read_bytes(define::RTMP_HANDSHAKE_SIZE)?;
210         Ok(())
211     }
212 
write_s0(&mut self) -> Result<(), HandshakeError>213     fn write_s0(&mut self) -> Result<(), HandshakeError> {
214         self.writer.write_u8(define::RTMP_VERSION as u8)?;
215         Ok(())
216     }
217 
write_s1(&mut self) -> Result<(), HandshakeError>218     fn write_s1(&mut self) -> Result<(), HandshakeError> {
219         /*write the s1 data*/
220         let mut writer = BytesWriter::new();
221 
222         writer.write_u32::<BigEndian>(utils::current_time())?;
223         writer.write(&define::RTMP_SERVER_VERSION)?;
224         writer.write_random_bytes(define::RTMP_HANDSHAKE_SIZE as u32 - 8)?;
225 
226         /*generate the digest*/
227         let mut key = BytesMut::new();
228         key.extend_from_slice(define::RTMP_SERVER_KEY_FIRST_HALF.as_bytes());
229 
230         let mut digest_processor = DigestProcessor::new(writer.extract_current_bytes(), key);
231         let content = digest_processor.generate_and_fill_digest()?;
232 
233         /*write*/
234         self.writer.write(&content[..])?;
235         Ok(())
236     }
237 
write_s2(&mut self) -> Result<(), HandshakeError>238     fn write_s2(&mut self) -> Result<(), HandshakeError> {
239         /*write the s2 data*/
240         let mut writer = BytesWriter::new();
241 
242         writer.write_u32::<BigEndian>(utils::current_time())?;
243         writer.write_u32::<BigEndian>(self.c1_timestamp)?;
244         writer.write_random_bytes(define::RTMP_HANDSHAKE_SIZE as u32 - 8)?;
245 
246         /*generate the key for s2*/
247         let mut key = BytesMut::new();
248         key.extend_from_slice(&define::RTMP_SERVER_KEY);
249 
250         let mut digest_processor = DigestProcessor::new(BytesMut::new(), key);
251         let tmp_key = digest_processor.make_digest(Vec::from(&self.c1_digest[..]))?;
252 
253         /*generate the digest for s2 data*/
254         let mut data: BytesMut = BytesMut::new();
255         data.extend_from_slice(&writer.get_current_bytes()[..1504]);
256 
257         let mut digest_processor_2 = DigestProcessor::new(BytesMut::new(), tmp_key);
258         let digtest = digest_processor_2.make_digest(Vec::from(&data[..]))?;
259 
260         let content = [data, digtest].concat();
261 
262         /*write*/
263         self.writer.write(&content[..])?;
264 
265         Ok(())
266     }
267 }
268 
269 pub struct HandshakeServer {
270     simple_handshaker: SimpleHandshakeServer,
271     complex_handshaker: ComplexHandshakeServer,
272     is_complex: bool,
273 
274     saved_data: BytesMut,
275 }
276 
277 impl HandshakeServer {
new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self278     pub fn new(io: Arc<Mutex<Box<dyn TNetIO + Send + Sync>>>) -> Self {
279         Self {
280             simple_handshaker: SimpleHandshakeServer::new(io.clone()),
281             complex_handshaker: ComplexHandshakeServer::new(io),
282             is_complex: true,
283 
284             saved_data: BytesMut::new(),
285         }
286     }
287 
extend_data(&mut self, data: &[u8])288     pub fn extend_data(&mut self, data: &[u8]) {
289         if self.is_complex {
290             self.complex_handshaker.extend_data(data);
291             self.saved_data.extend_from_slice(data);
292         } else {
293             self.simple_handshaker.extend_data(data);
294         }
295     }
296 
state(&mut self) -> ServerHandshakeState297     pub fn state(&mut self) -> ServerHandshakeState {
298         if self.is_complex {
299             self.complex_handshaker.state
300         } else {
301             self.simple_handshaker.state
302         }
303     }
304 
get_remaining_bytes(&mut self) -> BytesMut305     pub fn get_remaining_bytes(&mut self) -> BytesMut {
306         match self.is_complex {
307             true => self.complex_handshaker.reader.get_remaining_bytes(),
308             false => self.simple_handshaker.reader.get_remaining_bytes(),
309         }
310     }
handshake(&mut self) -> Result<(), HandshakeError>311     pub async fn handshake(&mut self) -> Result<(), HandshakeError> {
312         match self.is_complex {
313             true => {
314                 let result = self.complex_handshaker.handshake().await;
315                 match result {
316                     Ok(_) => {
317                         //println!("Complex handshake is successfully!!")
318                     }
319                     Err(err) => {
320                         log::warn!("complex handshake failed.. err:{}", err);
321                         self.is_complex = false;
322                         let data = self.saved_data.clone();
323                         self.extend_data(&data[..]);
324                         self.simple_handshaker.handshake().await?;
325                     }
326                 }
327             }
328             false => {
329                 self.simple_handshaker.handshake().await?;
330             }
331         }
332 
333         Ok(())
334     }
335 }
336