xref: /webrtc/util/src/vnet/conn.rs (revision 5b79f08a)
1 #[cfg(test)]
2 mod conn_test;
3 
4 use crate::conn::Conn;
5 use crate::error::*;
6 use crate::sync::RwLock;
7 use crate::vnet::chunk::{Chunk, ChunkUdp};
8 
9 use std::net::{IpAddr, SocketAddr};
10 use tokio::sync::{mpsc, Mutex};
11 
12 use async_trait::async_trait;
13 use std::sync::atomic::{AtomicBool, Ordering};
14 use std::sync::Arc;
15 
/// Capacity of the bounded channel that queues inbound chunks for a
/// `UdpConn` until `recv_from` consumes them.
const MAX_READ_QUEUE_SIZE: usize = 1024;
17 
18 /// vNet implements this
19 #[async_trait]
20 pub(crate) trait ConnObserver {
write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>21     async fn write(&self, c: Box<dyn Chunk + Send + Sync>) -> Result<()>;
on_closed(&self, addr: SocketAddr)22     async fn on_closed(&self, addr: SocketAddr);
determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>23     fn determine_source_ip(&self, loc_ip: IpAddr, dst_ip: IpAddr) -> Option<IpAddr>;
24 }
25 
26 pub(crate) type ChunkChTx = mpsc::Sender<Box<dyn Chunk + Send + Sync>>;
27 
28 /// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
29 /// comatible with net.PacketConn and net.Conn
30 pub(crate) struct UdpConn {
31     loc_addr: SocketAddr,
32     rem_addr: RwLock<Option<SocketAddr>>,
33     read_ch_tx: Arc<Mutex<Option<ChunkChTx>>>,
34     read_ch_rx: Mutex<mpsc::Receiver<Box<dyn Chunk + Send + Sync>>>,
35     closed: AtomicBool,
36     obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>,
37 }
38 
39 impl UdpConn {
new( loc_addr: SocketAddr, rem_addr: Option<SocketAddr>, obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>, ) -> Self40     pub(crate) fn new(
41         loc_addr: SocketAddr,
42         rem_addr: Option<SocketAddr>,
43         obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>,
44     ) -> Self {
45         let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE);
46 
47         UdpConn {
48             loc_addr,
49             rem_addr: RwLock::new(rem_addr),
50             read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))),
51             read_ch_rx: Mutex::new(read_ch_rx),
52             closed: AtomicBool::new(false),
53             obs,
54         }
55     }
56 
get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>>57     pub(crate) fn get_inbound_ch(&self) -> Arc<Mutex<Option<ChunkChTx>>> {
58         Arc::clone(&self.read_ch_tx)
59     }
60 }
61 
62 #[async_trait]
63 impl Conn for UdpConn {
connect(&self, addr: SocketAddr) -> Result<()>64     async fn connect(&self, addr: SocketAddr) -> Result<()> {
65         self.rem_addr.write().replace(addr);
66 
67         Ok(())
68     }
recv(&self, buf: &mut [u8]) -> Result<usize>69     async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
70         let (n, _) = self.recv_from(buf).await?;
71         Ok(n)
72     }
73 
74     /// recv_from reads a packet from the connection,
75     /// copying the payload into p. It returns the number of
76     /// bytes copied into p and the return address that
77     /// was on the packet.
78     /// It returns the number of bytes read (0 <= n <= len(p))
79     /// and any error encountered. Callers should always process
80     /// the n > 0 bytes returned before considering the error err.
recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>81     async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
82         let mut read_ch = self.read_ch_rx.lock().await;
83         let rem_addr = *self.rem_addr.read();
84         while let Some(chunk) = read_ch.recv().await {
85             let user_data = chunk.user_data();
86             let n = std::cmp::min(buf.len(), user_data.len());
87             buf[..n].copy_from_slice(&user_data[..n]);
88             let addr = chunk.source_addr();
89             {
90                 if let Some(rem_addr) = &rem_addr {
91                     if &addr != rem_addr {
92                         continue; // discard (shouldn't happen)
93                     }
94                 }
95             }
96             return Ok((n, addr));
97         }
98 
99         Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "Connection Aborted").into())
100     }
101 
send(&self, buf: &[u8]) -> Result<usize>102     async fn send(&self, buf: &[u8]) -> Result<usize> {
103         let rem_addr = *self.rem_addr.read();
104         if let Some(rem_addr) = rem_addr {
105             self.send_to(buf, rem_addr).await
106         } else {
107             Err(Error::ErrNoRemAddr)
108         }
109     }
110 
111     /// send_to writes a packet with payload p to addr.
112     /// send_to can be made to time out and return
send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize>113     async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
114         let src_ip = {
115             let obs = self.obs.lock().await;
116             match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) {
117                 Some(ip) => ip,
118                 None => return Err(Error::ErrLocAddr),
119             }
120         };
121 
122         let src_addr = SocketAddr::new(src_ip, self.loc_addr.port());
123 
124         let mut chunk = ChunkUdp::new(src_addr, target);
125         chunk.user_data = buf.to_vec();
126         {
127             let c: Box<dyn Chunk + Send + Sync> = Box::new(chunk);
128             let obs = self.obs.lock().await;
129             obs.write(c).await?
130         }
131 
132         Ok(buf.len())
133     }
134 
local_addr(&self) -> Result<SocketAddr>135     fn local_addr(&self) -> Result<SocketAddr> {
136         Ok(self.loc_addr)
137     }
138 
remote_addr(&self) -> Option<SocketAddr>139     fn remote_addr(&self) -> Option<SocketAddr> {
140         *self.rem_addr.read()
141     }
142 
close(&self) -> Result<()>143     async fn close(&self) -> Result<()> {
144         if self.closed.load(Ordering::SeqCst) {
145             return Err(Error::ErrAlreadyClosed);
146         }
147         self.closed.store(true, Ordering::SeqCst);
148         {
149             let mut reach_ch = self.read_ch_tx.lock().await;
150             reach_ch.take();
151         }
152         {
153             let obs = self.obs.lock().await;
154             obs.on_closed(self.loc_addr).await;
155         }
156 
157         Ok(())
158     }
159 }
160