1 use super::bytesio_errors::{BytesIOError, BytesIOErrorValue}; 2 3 use bytes::BufMut; 4 use bytes::Bytes; 5 use bytes::BytesMut; 6 use futures::StreamExt; 7 8 use std::time::Duration; 9 10 use tokio::net::TcpStream; 11 use tokio::time::sleep; 12 13 use futures::SinkExt; 14 use std::time::{SystemTime, UNIX_EPOCH}; 15 use tokio_util::codec::BytesCodec; 16 use tokio_util::codec::Framed; 17 18 use async_trait::async_trait; 19 use std::net::SocketAddr; 20 use tokio::net::UdpSocket; 21 22 pub enum NetType { 23 TCP, 24 UDP, 25 } 26 27 #[async_trait] 28 pub trait TNetIO: Send + Sync { 29 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>; 30 async fn read(&mut self) -> Result<BytesMut, BytesIOError>; 31 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>; 32 fn get_net_type(&self) -> NetType; 33 } 34 35 pub struct UdpIO { 36 socket: UdpSocket, 37 } 38 39 impl UdpIO { 40 pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> { 41 let remote_address = format!("{remote_domain}:{remote_port}"); 42 log::info!("remote address: {}", remote_address); 43 let local_address = format!("0.0.0.0:{local_port}"); 44 if let Ok(local_socket) = UdpSocket::bind(local_address).await { 45 if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() { 46 if let Err(err) = local_socket.connect(remote_socket_addr).await { 47 log::info!("connect to remote udp socket error: {}", err); 48 } 49 } 50 return Some(Self { 51 socket: local_socket, 52 }); 53 } 54 55 None 56 } 57 pub fn get_local_port(&self) -> Option<u16> { 58 if let Ok(local_addr) = self.socket.local_addr() { 59 log::info!("local address: {}", local_addr); 60 return Some(local_addr.port()); 61 } 62 63 None 64 } 65 } 66 67 #[async_trait] 68 impl TNetIO for UdpIO { 69 fn get_net_type(&self) -> NetType { 70 NetType::UDP 71 } 72 73 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> { 74 self.socket.send(bytes.as_ref()).await?; 75 Ok(()) 76 } 77 78 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> { 79 let begin_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); 80 81 loop { 82 match self.read().await { 83 Ok(data) => { 84 return Ok(data); 85 } 86 Err(_) => { 87 sleep(Duration::from_millis(50)).await; 88 let current_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); 89 90 if current_millseconds - begin_millseconds > duration { 91 return Err(BytesIOError { 92 value: BytesIOErrorValue::TimeoutError, 93 }); 94 } 95 } 96 } 97 } 98 } 99 100 async fn read(&mut self) -> Result<BytesMut, BytesIOError> { 101 let mut buf = vec![0; 4096]; 102 let len = self.socket.recv(&mut buf).await?; 103 let mut rv = BytesMut::new(); 104 rv.put(&buf[..len]); 105 106 Ok(rv) 107 } 108 } 109 110 pub struct TcpIO { 111 stream: Framed<TcpStream, BytesCodec>, 112 //timeout: Duration, 113 } 114 115 impl TcpIO { 116 pub fn new(stream: TcpStream) -> Self { 117 Self { 118 stream: Framed::new(stream, BytesCodec::new()), 119 // timeout: ms, 120 } 121 } 122 } 123 124 #[async_trait] 125 impl TNetIO for TcpIO { 126 fn get_net_type(&self) -> NetType { 127 NetType::TCP 128 } 129 130 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> { 131 self.stream.send(bytes).await?; 132 133 Ok(()) 134 } 135 136 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> { 137 let begin_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); 138 139 loop { 140 match self.read().await { 141 Ok(data) => { 142 return Ok(data); 143 } 144 Err(_) => { 145 sleep(Duration::from_millis(50)).await; 146 let current_millseconds = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); 147 148 if current_millseconds - begin_millseconds > duration { 149 return Err(BytesIOError { 150 value: BytesIOErrorValue::TimeoutError, 151 }); 152 } 153 } 154 } 155 } 156 } 157 158 async fn read(&mut self) -> Result<BytesMut, BytesIOError> { 159 let message = self.stream.next().await; 160 161 match message { 162 Some(data) => match data { 163 Ok(bytes) => Ok(bytes), 164 Err(err) => Err(BytesIOError { 165 value: BytesIOErrorValue::IOError(err), 166 }), 167 }, 168 None => Err(BytesIOError { 169 value: BytesIOErrorValue::NoneReturn, 170 }), 171 } 172 } 173 } 174