xref: /webrtc/ice/src/util/mod.rs (revision 0c35e613)
1 #[cfg(test)]
2 mod util_test;
3 
4 use crate::agent::agent_config::{InterfaceFilterFn, IpFilterFn};
5 use crate::error::*;
6 use crate::network_type::*;
7 
8 use std::collections::HashSet;
9 use std::net::{IpAddr, SocketAddr};
10 use std::sync::Arc;
11 use stun::{agent::*, attributes::*, integrity::*, message::*, textattrs::*, xoraddr::*};
12 use tokio::time::Duration;
13 use util::{vnet::net::*, Conn};
14 
create_addr(_network: NetworkType, ip: IpAddr, port: u16) -> SocketAddr15 pub fn create_addr(_network: NetworkType, ip: IpAddr, port: u16) -> SocketAddr {
16     /*if network.is_tcp(){
17         return &net.TCPAddr{IP: ip, Port: port}
18     default:
19         return &net.UDPAddr{IP: ip, Port: port}
20     }*/
21     SocketAddr::new(ip, port)
22 }
23 
assert_inbound_username(m: &Message, expected_username: &str) -> Result<()>24 pub fn assert_inbound_username(m: &Message, expected_username: &str) -> Result<()> {
25     let mut username = Username::new(ATTR_USERNAME, String::new());
26     username.get_from(m)?;
27 
28     if username.to_string() != expected_username {
29         return Err(Error::Other(format!(
30             "{:?} expected({}) actual({})",
31             Error::ErrMismatchUsername,
32             expected_username,
33             username,
34         )));
35     }
36 
37     Ok(())
38 }
39 
assert_inbound_message_integrity(m: &mut Message, key: &[u8]) -> Result<()>40 pub fn assert_inbound_message_integrity(m: &mut Message, key: &[u8]) -> Result<()> {
41     let message_integrity_attr = MessageIntegrity(key.to_vec());
42     Ok(message_integrity_attr.check(m)?)
43 }
44 
45 /// Initiates a stun requests to `server_addr` using conn, reads the response and returns the
46 /// `XORMappedAddress` returned by the stun server.
47 /// Adapted from stun v0.2.
get_xormapped_addr( conn: &Arc<dyn Conn + Send + Sync>, server_addr: SocketAddr, deadline: Duration, ) -> Result<XorMappedAddress>48 pub async fn get_xormapped_addr(
49     conn: &Arc<dyn Conn + Send + Sync>,
50     server_addr: SocketAddr,
51     deadline: Duration,
52 ) -> Result<XorMappedAddress> {
53     let resp = stun_request(conn, server_addr, deadline).await?;
54     let mut addr = XorMappedAddress::default();
55     addr.get_from(&resp)?;
56     Ok(addr)
57 }
58 
59 const MAX_MESSAGE_SIZE: usize = 1280;
60 
stun_request( conn: &Arc<dyn Conn + Send + Sync>, server_addr: SocketAddr, deadline: Duration, ) -> Result<Message>61 pub async fn stun_request(
62     conn: &Arc<dyn Conn + Send + Sync>,
63     server_addr: SocketAddr,
64     deadline: Duration,
65 ) -> Result<Message> {
66     let mut request = Message::new();
67     request.build(&[Box::new(BINDING_REQUEST), Box::new(TransactionId::new())])?;
68 
69     conn.send_to(&request.raw, server_addr).await?;
70     let mut bs = vec![0_u8; MAX_MESSAGE_SIZE];
71     let (n, _) = if deadline > Duration::from_secs(0) {
72         match tokio::time::timeout(deadline, conn.recv_from(&mut bs)).await {
73             Ok(result) => match result {
74                 Ok((n, addr)) => (n, addr),
75                 Err(err) => return Err(Error::Other(err.to_string())),
76             },
77             Err(err) => return Err(Error::Other(err.to_string())),
78         }
79     } else {
80         conn.recv_from(&mut bs).await?
81     };
82 
83     let mut res = Message::new();
84     res.raw = bs[..n].to_vec();
85     res.decode()?;
86 
87     Ok(res)
88 }
89 
local_interfaces( vnet: &Arc<Net>, interface_filter: &Option<InterfaceFilterFn>, ip_filter: &Option<IpFilterFn>, network_types: &[NetworkType], ) -> HashSet<IpAddr>90 pub async fn local_interfaces(
91     vnet: &Arc<Net>,
92     interface_filter: &Option<InterfaceFilterFn>,
93     ip_filter: &Option<IpFilterFn>,
94     network_types: &[NetworkType],
95 ) -> HashSet<IpAddr> {
96     let mut ips = HashSet::new();
97     let interfaces = vnet.get_interfaces().await;
98 
99     let (mut ipv4requested, mut ipv6requested) = (false, false);
100     for typ in network_types {
101         if typ.is_ipv4() {
102             ipv4requested = true;
103         }
104         if typ.is_ipv6() {
105             ipv6requested = true;
106         }
107     }
108 
109     for iface in interfaces {
110         if let Some(filter) = interface_filter {
111             if !filter(iface.name()) {
112                 continue;
113             }
114         }
115 
116         for ipnet in iface.addrs() {
117             let ipaddr = ipnet.addr();
118 
119             if !ipaddr.is_loopback()
120                 && ((ipv4requested && ipaddr.is_ipv4()) || (ipv6requested && ipaddr.is_ipv6()))
121                 && ip_filter
122                     .as_ref()
123                     .map(|filter| filter(ipaddr))
124                     .unwrap_or(true)
125             {
126                 ips.insert(ipaddr);
127             }
128         }
129     }
130 
131     ips
132 }
133 
listen_udp_in_port_range( vnet: &Arc<Net>, port_max: u16, port_min: u16, laddr: SocketAddr, ) -> Result<Arc<dyn Conn + Send + Sync>>134 pub async fn listen_udp_in_port_range(
135     vnet: &Arc<Net>,
136     port_max: u16,
137     port_min: u16,
138     laddr: SocketAddr,
139 ) -> Result<Arc<dyn Conn + Send + Sync>> {
140     if laddr.port() != 0 || (port_min == 0 && port_max == 0) {
141         return Ok(vnet.bind(laddr).await?);
142     }
143     let i = if port_min == 0 { 1 } else { port_min };
144     let j = if port_max == 0 { 0xFFFF } else { port_max };
145     if i > j {
146         return Err(Error::ErrPort);
147     }
148 
149     let port_start = rand::random::<u16>() % (j - i + 1) + i;
150     let mut port_current = port_start;
151     loop {
152         let laddr = SocketAddr::new(laddr.ip(), port_current);
153         match vnet.bind(laddr).await {
154             Ok(c) => return Ok(c),
155             Err(err) => log::debug!("failed to listen {}: {}", laddr, err),
156         };
157 
158         port_current += 1;
159         if port_current > j {
160             port_current = i;
161         }
162         if port_current == port_start {
163             break;
164         }
165     }
166 
167     Err(Error::ErrPort)
168 }
169