xref: /xiu/library/bytesio/src/bytesio.rs (revision 46f4d48b)
1 use std::net::SocketAddr;
2 use std::time::Duration;
3 
4 use async_trait::async_trait;
5 use bytes::BufMut;
6 use bytes::Bytes;
7 use bytes::BytesMut;
8 use futures::SinkExt;
9 use futures::StreamExt;
10 use tokio::net::TcpStream;
11 use tokio::net::UdpSocket;
12 use tokio_util::codec::BytesCodec;
13 use tokio_util::codec::Framed;
14 
15 use super::bytesio_errors::{BytesIOError, BytesIOErrorValue};
16 
/// Transport protocol reported by `TNetIO::get_net_type`.
///
/// The derives make the value printable in logs, cheap to copy and directly
/// comparable, which callers need when branching on the transport kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetType {
    /// The I/O object is backed by a TCP stream.
    TCP,
    /// The I/O object is backed by a UDP socket.
    UDP,
}
21 
22 #[async_trait]
23 pub trait TNetIO: Send + Sync {
write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>24     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>;
read(&mut self) -> Result<BytesMut, BytesIOError>25     async fn read(&mut self) -> Result<BytesMut, BytesIOError>;
read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>26     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>;
get_net_type(&self) -> NetType27     fn get_net_type(&self) -> NetType;
28 }
29 
30 pub struct UdpIO {
31     socket: UdpSocket,
32 }
33 
34 impl UdpIO {
new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self>35     pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> {
36         let remote_address = format!("{remote_domain}:{remote_port}");
37         log::info!("remote address: {}", remote_address);
38         let local_address = format!("0.0.0.0:{local_port}");
39         if let Ok(local_socket) = UdpSocket::bind(local_address).await {
40             if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() {
41                 if let Err(err) = local_socket.connect(remote_socket_addr).await {
42                     log::info!("connect to remote udp socket error: {}", err);
43                 }
44             }
45             return Some(Self {
46                 socket: local_socket,
47             });
48         }
49 
50         None
51     }
get_local_port(&self) -> Option<u16>52     pub fn get_local_port(&self) -> Option<u16> {
53         if let Ok(local_addr) = self.socket.local_addr() {
54             log::info!("local address: {}", local_addr);
55             return Some(local_addr.port());
56         }
57 
58         None
59     }
60 }
61 
62 #[async_trait]
63 impl TNetIO for UdpIO {
get_net_type(&self) -> NetType64     fn get_net_type(&self) -> NetType {
65         NetType::UDP
66     }
67 
write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>68     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
69         self.socket.send(bytes.as_ref()).await?;
70         Ok(())
71     }
72 
read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>73     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
74         match tokio::time::timeout(duration, self.read()).await {
75             Ok(data) => data,
76             Err(err) => Err(BytesIOError {
77                 value: BytesIOErrorValue::TimeoutError(err),
78             })
79         }
80     }
81 
read(&mut self) -> Result<BytesMut, BytesIOError>82     async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
83         let mut buf = vec![0; 4096];
84         let len = self.socket.recv(&mut buf).await?;
85         let mut rv = BytesMut::new();
86         rv.put(&buf[..len]);
87 
88         Ok(rv)
89     }
90 }
91 
92 pub struct TcpIO {
93     stream: Framed<TcpStream, BytesCodec>,
94     //timeout: Duration,
95 }
96 
97 impl TcpIO {
new(stream: TcpStream) -> Self98     pub fn new(stream: TcpStream) -> Self {
99         Self {
100             stream: Framed::new(stream, BytesCodec::new()),
101             // timeout: ms,
102         }
103     }
104 }
105 
106 #[async_trait]
107 impl TNetIO for TcpIO {
get_net_type(&self) -> NetType108     fn get_net_type(&self) -> NetType {
109         NetType::TCP
110     }
111 
write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>112     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
113         self.stream.send(bytes).await?;
114 
115         Ok(())
116     }
117 
read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>118     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
119         match tokio::time::timeout(duration, self.read()).await {
120             Ok(data) => data,
121             Err(err) => Err(BytesIOError {
122                 value: BytesIOErrorValue::TimeoutError(err),
123             })
124         }
125     }
126 
read(&mut self) -> Result<BytesMut, BytesIOError>127     async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
128         let message = self.stream.next().await;
129 
130         match message {
131             Some(data) => match data {
132                 Ok(bytes) => Ok(bytes),
133                 Err(err) => Err(BytesIOError {
134                     value: BytesIOErrorValue::IOError(err),
135                 }),
136             },
137             None => Err(BytesIOError {
138                 value: BytesIOErrorValue::NoneReturn,
139             }),
140         }
141     }
142 }
143