xref: /xiu/library/bytesio/src/bytesio.rs (revision 2f8005f4)
1 use super::bytesio_errors::{BytesIOError, BytesIOErrorValue};
2 
3 use bytes::BufMut;
4 use bytes::Bytes;
5 use bytes::BytesMut;
6 use futures::StreamExt;
7 
8 use std::time::Duration;
9 
10 use tokio::net::TcpStream;
11 use tokio::time::sleep;
12 
13 use futures::SinkExt;
14 use std::time::{SystemTime, UNIX_EPOCH};
15 use tokio_util::codec::BytesCodec;
16 use tokio_util::codec::Framed;
17 
18 use async_trait::async_trait;
19 use std::net::SocketAddr;
20 use tokio::net::UdpSocket;
21 
/// Transport protocol that backs a [`TNetIO`] implementation.
// The all-caps variant names are existing public API; renaming them to `Tcp`/`Udp`
// would break callers, so the acronym lint is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetType {
    TCP,
    UDP,
}
26 
27 #[async_trait]
28 pub trait TNetIO: Send + Sync {
29     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>;
30     async fn read(&mut self) -> Result<BytesMut, BytesIOError>;
31     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>;
32     fn get_net_type(&self) -> NetType;
33 }
34 
35 pub struct UdpIO {
36     socket: UdpSocket,
37 }
38 
39 impl UdpIO {
40     pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> {
41         let remote_address = format!("{remote_domain}:{remote_port}");
42         log::info!("remote address: {}", remote_address);
43         let local_address = format!("0.0.0.0:{local_port}");
44         if let Ok(local_socket) = UdpSocket::bind(local_address).await {
45             if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() {
46                 if let Err(err) = local_socket.connect(remote_socket_addr).await {
47                     log::info!("connect to remote udp socket error: {}", err);
48                 }
49             }
50             return Some(Self {
51                 socket: local_socket,
52             });
53         }
54 
55         None
56     }
57     pub fn get_local_port(&self) -> Option<u16> {
58         if let Ok(local_addr) = self.socket.local_addr() {
59             log::info!("local address: {}", local_addr);
60             return Some(local_addr.port());
61         }
62 
63         None
64     }
65 }
66 
67 #[async_trait]
68 impl TNetIO for UdpIO {
69     fn get_net_type(&self) -> NetType {
70         NetType::UDP
71     }
72 
73     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
74         self.socket.send(bytes.as_ref()).await?;
75         Ok(())
76     }
77 
78     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
79         let begin_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
80 
81         loop {
82             match self.read().await {
83                 Ok(data) => {
84                     return Ok(data);
85                 }
86                 Err(_) => {
87                     sleep(Duration::from_millis(50)).await;
88                     let current_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
89 
90                     if current_millseconds - begin_millseconds > duration {
91                         return Err(BytesIOError {
92                             value: BytesIOErrorValue::TimeoutError,
93                         });
94                     }
95                 }
96             }
97         }
98     }
99 
100     async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
101         let mut buf = vec![0; 4096];
102         let len = self.socket.recv(&mut buf).await?;
103         let mut rv = BytesMut::new();
104         rv.put(&buf[..len]);
105 
106         Ok(rv)
107     }
108 }
109 
110 pub struct TcpIO {
111     stream: Framed<TcpStream, BytesCodec>,
112     //timeout: Duration,
113 }
114 
115 impl TcpIO {
116     pub fn new(stream: TcpStream) -> Self {
117         Self {
118             stream: Framed::new(stream, BytesCodec::new()),
119             // timeout: ms,
120         }
121     }
122 }
123 
124 #[async_trait]
125 impl TNetIO for TcpIO {
126     fn get_net_type(&self) -> NetType {
127         NetType::TCP
128     }
129 
130     async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
131         self.stream.send(bytes).await?;
132 
133         Ok(())
134     }
135 
136     async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
137         let begin_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
138 
139         loop {
140             match self.read().await {
141                 Ok(data) => {
142                     return Ok(data);
143                 }
144                 Err(_) => {
145                     sleep(Duration::from_millis(50)).await;
146                     let current_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
147 
148                     if current_millseconds - begin_millseconds > duration {
149                         return Err(BytesIOError {
150                             value: BytesIOErrorValue::TimeoutError,
151                         });
152                     }
153                 }
154             }
155         }
156     }
157 
158     async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
159         let message = self.stream.next().await;
160 
161         match message {
162             Some(data) => match data {
163                 Ok(bytes) => Ok(bytes),
164                 Err(err) => Err(BytesIOError {
165                     value: BytesIOErrorValue::IOError(err),
166                 }),
167             },
168             None => Err(BytesIOError {
169                 value: BytesIOErrorValue::NoneReturn,
170             }),
171         }
172     }
173 }
174