1 use std::net::SocketAddr; 2 use std::time::Duration; 3 4 use async_trait::async_trait; 5 use bytes::BufMut; 6 use bytes::Bytes; 7 use bytes::BytesMut; 8 use futures::SinkExt; 9 use futures::StreamExt; 10 use tokio::net::TcpStream; 11 use tokio::net::UdpSocket; 12 use tokio_util::codec::BytesCodec; 13 use tokio_util::codec::Framed; 14 15 use super::bytesio_errors::{BytesIOError, BytesIOErrorValue}; 16 17 pub enum NetType { 18 TCP, 19 UDP, 20 } 21 22 #[async_trait] 23 pub trait TNetIO: Send + Sync { write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>24 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>; read(&mut self) -> Result<BytesMut, BytesIOError>25 async fn read(&mut self) -> Result<BytesMut, BytesIOError>; read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>26 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>; get_net_type(&self) -> NetType27 fn get_net_type(&self) -> NetType; 28 } 29 30 pub struct UdpIO { 31 socket: UdpSocket, 32 } 33 34 impl UdpIO { new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self>35 pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> { 36 let remote_address = format!("{remote_domain}:{remote_port}"); 37 log::info!("remote address: {}", remote_address); 38 let local_address = format!("0.0.0.0:{local_port}"); 39 if let Ok(local_socket) = UdpSocket::bind(local_address).await { 40 if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() { 41 if let Err(err) = local_socket.connect(remote_socket_addr).await { 42 log::info!("connect to remote udp socket error: {}", err); 43 } 44 } 45 return Some(Self { 46 socket: local_socket, 47 }); 48 } 49 50 None 51 } get_local_port(&self) -> Option<u16>52 pub fn get_local_port(&self) -> Option<u16> { 53 if let Ok(local_addr) = self.socket.local_addr() { 54 log::info!("local address: {}", local_addr); 55 return Some(local_addr.port()); 56 } 57 58 None 59 } 60 } 61 62 #[async_trait] 63 impl TNetIO for UdpIO { get_net_type(&self) -> NetType64 fn get_net_type(&self) -> NetType { 65 NetType::UDP 66 } 67 write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>68 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> { 69 self.socket.send(bytes.as_ref()).await?; 70 Ok(()) 71 } 72 read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>73 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> { 74 match tokio::time::timeout(duration, self.read()).await { 75 Ok(data) => data, 76 Err(err) => Err(BytesIOError { 77 value: BytesIOErrorValue::TimeoutError(err), 78 }) 79 } 80 } 81 read(&mut self) -> Result<BytesMut, BytesIOError>82 async fn read(&mut self) -> Result<BytesMut, BytesIOError> { 83 let mut buf = vec![0; 4096]; 84 let len = self.socket.recv(&mut buf).await?; 85 let mut rv = BytesMut::new(); 86 rv.put(&buf[..len]); 87 88 Ok(rv) 89 } 90 } 91 92 pub struct TcpIO { 93 stream: Framed<TcpStream, BytesCodec>, 94 //timeout: Duration, 95 } 96 97 impl TcpIO { new(stream: TcpStream) -> Self98 pub fn new(stream: TcpStream) -> Self { 99 Self { 100 stream: Framed::new(stream, BytesCodec::new()), 101 // timeout: ms, 102 } 103 } 104 } 105 106 #[async_trait] 107 impl TNetIO for TcpIO { get_net_type(&self) -> NetType108 fn get_net_type(&self) -> NetType { 109 NetType::TCP 110 } 111 write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>112 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> { 113 self.stream.send(bytes).await?; 114 115 Ok(()) 116 } 117 read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>118 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> { 119 match tokio::time::timeout(duration, self.read()).await { 120 Ok(data) => data, 121 Err(err) => Err(BytesIOError { 122 value: BytesIOErrorValue::TimeoutError(err), 123 }) 124 } 125 } 126 read(&mut self) -> Result<BytesMut, BytesIOError>127 async fn read(&mut self) -> Result<BytesMut, BytesIOError> { 128 let message = self.stream.next().await; 129 130 match message { 131 Some(data) => match data { 132 Ok(bytes) => Ok(bytes), 133 Err(err) => Err(BytesIOError { 134 value: BytesIOErrorValue::IOError(err), 135 }), 136 }, 137 None => Err(BytesIOError { 138 value: BytesIOErrorValue::NoneReturn, 139 }), 140 } 141 } 142 } 143