1 #[cfg(test)] 2 mod conn_test; 3 4 use crate::conn::Conn; 5 use crate::error::*; 6 use crate::sync::RwLock; 7 use crate::vnet::chunk::{Chunk, ChunkUdp}; 8 9 use std::net::{IpAddr, SocketAddr}; 10 use tokio::sync::{mpsc, Mutex}; 11 12 use async_trait::async_trait; 13 use std::sync::atomic::{AtomicBool, Ordering}; 14 use std::sync::Arc; 15 16 const MAX_READ_QUEUE_SIZE: usize = 1024; 17 18 /// vNet implements this 19 #[async_trait] 20 pub(crate) trait ConnObserver { write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>21 async fn write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>; on_closed(&self, addr: SocketAddr)22 async fn on_closed(&self, addr: SocketAddr); determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>23 fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>; 24 } 25 26 pub(crate) type ChunkChTx = mpsc::Sender<Box<dyn Chunk + Send + Sync>>; 27 28 /// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. 29 /// comatible with net.PacketConn and net.Conn 30 pub(crate) struct UdpConn { 31 loc_addr: SocketAddr, 32 rem_addr: RwLock<Option<SocketAddr>>, 33 read_ch_tx: Arc<Mutex<Option<ChunkChTx>>>, 34 read_ch_rx: Mutex<mpsc::Receiver<Box<dyn Chunk + Send + Sync>>>, 35 closed: AtomicBool, 36 obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>, 37 } 38 39 impl UdpConn { new( loc_addr: SocketAddr, rem_addr: Option<SocketAddr>, obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>, ) -> Self40 pub(crate) fn new( 41 loc_addr: SocketAddr, 42 rem_addr: Option<SocketAddr>, 43 obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>, 44 ) -> Self { 45 let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE); 46 47 UdpConn { 48 loc_addr, 49 rem_addr: RwLock::new(rem_addr), 50 read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))), 51 read_ch_rx: Mutex::new(read_ch_rx), 52 closed: AtomicBool::new(false), 53 obs, 54 } 55 } 56 get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>>57 pub(crate) fn get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>> { 58 Arc::clone(&self.read_ch_tx) 59 } 60 } 61 62 #[async_trait] 63 impl Conn for UdpConn { connect(&self, addr: SocketAddr) -> Result<()>64 async fn connect(&self, addr: SocketAddr) -> Result<()> { 65 self.rem_addr.write().replace(addr); 66 67 Ok(()) 68 } recv(&self, buf: &mut [u8]) -> Result<usize>69 async fn recv(&self, buf: &mut [u8]) -> Result<usize> { 70 let (n, _) = self.recv_from(buf).await?; 71 Ok(n) 72 } 73 74 /// recv_from reads a packet from the connection, 75 /// copying the payload into p. It returns the number of 76 /// bytes copied into p and the return address that 77 /// was on the packet. 78 /// It returns the number of bytes read (0 <= n <= len(p)) 79 /// and any error encountered. Callers should always process 80 /// the n > 0 bytes returned before considering the error err. recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>81 async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { 82 let mut read_ch = self.read_ch_rx.lock().await; 83 let rem_addr = *self.rem_addr.read(); 84 while let Some(chunk) = read_ch.recv().await { 85 let user_data = chunk.user_data(); 86 let n = std::cmp::min(buf.len(), user_data.len()); 87 buf[..n].copy_from_slice(&user_data[..n]); 88 let addr = chunk.source_addr(); 89 { 90 if let Some(rem_addr) = &rem_addr { 91 if &addr != rem_addr { 92 continue; // discard (shouldn't happen) 93 } 94 } 95 } 96 return Ok((n, addr)); 97 } 98 99 Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Connection Aborted").into()) 100 } 101 send(&self, buf: &[u8]) -> Result<usize>102 async fn send(&self, buf: &[u8]) -> Result<usize> { 103 let rem_addr = *self.rem_addr.read(); 104 if let Some(rem_addr) = rem_addr { 105 self.send_to(buf, rem_addr).await 106 } else { 107 Err(Error::ErrNoRemAddr) 108 } 109 } 110 111 /// send_to writes a packet with payload p to addr. 112 /// send_to can be made to time out and return send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize>113 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> { 114 let src_ip = { 115 let obs = self.obs.lock().await; 116 match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) { 117 Some(ip) => ip, 118 None => return Err(Error::ErrLocAddr), 119 } 120 }; 121 122 let src_addr = SocketAddr::new(src_ip, self.loc_addr.port()); 123 124 let mut chunk = ChunkUdp::new(src_addr, target); 125 chunk.user_data = buf.to_vec(); 126 { 127 let c: Box<dyn Chunk + Send + Sync> = Box::new(chunk); 128 let obs = self.obs.lock().await; 129 obs.write(c).await? 130 } 131 132 Ok(buf.len()) 133 } 134 local_addr(&self) -> Result<SocketAddr>135 fn local_addr(&self) -> Result<SocketAddr> { 136 Ok(self.loc_addr) 137 } 138 remote_addr(&self) -> Option<SocketAddr>139 fn remote_addr(&self) -> Option<SocketAddr> { 140 *self.rem_addr.read() 141 } 142 close(&self) -> Result<()>143 async fn close(&self) -> Result<()> { 144 if self.closed.load(Ordering::SeqCst) { 145 return Err(Error::ErrAlreadyClosed); 146 } 147 self.closed.store(true, Ordering::SeqCst); 148 { 149 let mut reach_ch = self.read_ch_tx.lock().await; 150 reach_ch.take(); 151 } 152 { 153 let obs = self.obs.lock().await; 154 obs.on_closed(self.loc_addr).await; 155 } 156 157 Ok(()) 158 } 159 } 160