xref: /webrtc/mdns/src/conn/mod.rs (revision baa26754)
1 use crate::config::*;
2 use crate::error::*;
3 use crate::message::name::*;
4 use crate::message::{header::*, parser::*, question::*, resource::a::*, resource::*, *};
5 
6 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7 use std::sync::Arc;
8 use std::time::Duration;
9 
10 use core::sync::atomic;
11 use socket2::SockAddr;
12 use tokio::net::{ToSocketAddrs, UdpSocket};
13 use tokio::sync::mpsc;
14 use tokio::sync::Mutex;
15 
16 use util::ifaces;
17 
18 mod conn_test;
19 
/// Default destination for outgoing mDNS packets: the IPv4 multicast group
/// `224.0.0.251` on port `5353`.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Size, in bytes, of the buffer each inbound datagram is received into.
const INBOUND_BUFFER_SIZE: usize = 512;
/// Delay between repeated questions, used when the configured interval is zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_secs(1);
/// Bound on per-section record processing. The loops run over
/// `0..=MAX_MESSAGE_RECORDS`, so up to `MAX_MESSAGE_RECORDS + 1` records are
/// examined per section.
const MAX_MESSAGE_RECORDS: usize = 3;
/// TTL written into the A records this server sends in its answers.
const RESPONSE_TTL: u32 = 120;
26 
27 // Conn represents a mDNS Server
28 pub struct DnsConn {
29     socket: Arc<UdpSocket>,
30     dst_addr: SocketAddr,
31 
32     query_interval: Duration,
33     queries: Arc<Mutex<Vec<Query>>>,
34 
35     is_server_closed: Arc<atomic::AtomicBool>,
36     close_server: mpsc::Sender<()>,
37 }
38 
39 struct Query {
40     name_with_suffix: String,
41     query_result_chan: mpsc::Sender<QueryResult>,
42 }
43 
44 struct QueryResult {
45     answer: ResourceHeader,
46     addr: SocketAddr,
47 }
48 
49 impl DnsConn {
50     /// server establishes a mDNS connection over an existing connection
server(addr: SocketAddr, config: Config) -> Result<Self>51     pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
52         let socket = socket2::Socket::new(
53             socket2::Domain::IPV4,
54             socket2::Type::DGRAM,
55             Some(socket2::Protocol::UDP),
56         )?;
57 
58         #[cfg(feature = "reuse_port")]
59         #[cfg(target_family = "unix")]
60         socket.set_reuse_port(true)?;
61 
62         socket.set_reuse_address(true)?;
63         socket.set_broadcast(true)?;
64         socket.set_nonblocking(true)?;
65 
66         socket.bind(&SockAddr::from(addr))?;
67         {
68             let mut join_error_count = 0;
69             let interfaces = match ifaces::ifaces() {
70                 Ok(e) => e,
71                 Err(e) => {
72                     log::error!("Error getting interfaces: {:?}", e);
73                     return Err(Error::Other(e.to_string()));
74                 }
75             };
76 
77             for interface in &interfaces {
78                 if let Some(SocketAddr::V4(e)) = interface.addr {
79                     if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
80                     {
81                         log::trace!("Error connecting multicast, error: {:?}", e);
82                         join_error_count += 1;
83                         continue;
84                     }
85 
86                     log::trace!("Connected to interface address {:?}", e);
87                 }
88             }
89 
90             if join_error_count >= interfaces.len() {
91                 return Err(Error::ErrJoiningMulticastGroup);
92             }
93         }
94 
95         let socket = UdpSocket::from_std(socket.into())?;
96 
97         let local_names = config
98             .local_names
99             .iter()
100             .map(|l| l.to_string() + ".")
101             .collect();
102 
103         let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
104 
105         let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
106 
107         let (close_server_send, close_server_rcv) = mpsc::channel(1);
108 
109         let c = DnsConn {
110             query_interval: if config.query_interval != Duration::from_secs(0) {
111                 config.query_interval
112             } else {
113                 DEFAULT_QUERY_INTERVAL
114             },
115 
116             queries: Arc::new(Mutex::new(vec![])),
117             socket: Arc::new(socket),
118             dst_addr,
119             is_server_closed: Arc::clone(&is_server_closed),
120             close_server: close_server_send,
121         };
122 
123         let queries = c.queries.clone();
124         let socket = Arc::clone(&c.socket);
125 
126         tokio::spawn(async move {
127             DnsConn::start(
128                 close_server_rcv,
129                 is_server_closed,
130                 socket,
131                 local_names,
132                 dst_addr,
133                 queries,
134             )
135             .await
136         });
137 
138         Ok(c)
139     }
140 
141     /// Close closes the mDNS Conn
close(&self) -> Result<()>142     pub async fn close(&self) -> Result<()> {
143         log::info!("Closing connection");
144         if self.is_server_closed.load(atomic::Ordering::SeqCst) {
145             return Err(Error::ErrConnectionClosed);
146         }
147 
148         log::trace!("Sending close command to server");
149         match self.close_server.send(()).await {
150             Ok(_) => {
151                 log::trace!("Close command sent");
152                 Ok(())
153             }
154             Err(e) => {
155                 log::warn!("Error sending close command to server: {:?}", e);
156                 Err(Error::ErrConnectionClosed)
157             }
158         }
159     }
160 
161     /// Query sends mDNS Queries for the following name until
162     /// either there's a close signal or we get a result
query( &self, name: &str, mut close_query_signal: mpsc::Receiver<()>, ) -> Result<(ResourceHeader, SocketAddr)>163     pub async fn query(
164         &self,
165         name: &str,
166         mut close_query_signal: mpsc::Receiver<()>,
167     ) -> Result<(ResourceHeader, SocketAddr)> {
168         if self.is_server_closed.load(atomic::Ordering::SeqCst) {
169             return Err(Error::ErrConnectionClosed);
170         }
171 
172         let name_with_suffix = name.to_owned() + ".";
173 
174         let (query_tx, mut query_rx) = mpsc::channel(1);
175         {
176             let mut queries = self.queries.lock().await;
177             queries.push(Query {
178                 name_with_suffix: name_with_suffix.clone(),
179                 query_result_chan: query_tx,
180             });
181         }
182 
183         log::trace!("Sending query");
184         self.send_question(&name_with_suffix).await;
185 
186         loop {
187             tokio::select! {
188                 _ = tokio::time::sleep(self.query_interval) => {
189                     log::trace!("Sending query");
190                     self.send_question(&name_with_suffix).await
191                 },
192 
193                 _ = close_query_signal.recv() => {
194                     log::info!("Query close signal received.");
195                     return Err(Error::ErrConnectionClosed)
196                 },
197 
198                 res_opt = query_rx.recv() =>{
199                     log::info!("Received query result");
200                     if let Some(res) = res_opt{
201                         return Ok((res.answer, res.addr));
202                     }
203                 }
204             }
205         }
206     }
207 
send_question(&self, name: &str)208     async fn send_question(&self, name: &str) {
209         let packed_name = match Name::new(name) {
210             Ok(pn) => pn,
211             Err(err) => {
212                 log::warn!("Failed to construct mDNS packet: {}", err);
213                 return;
214             }
215         };
216 
217         let raw_query = {
218             let mut msg = Message {
219                 header: Header::default(),
220                 questions: vec![Question {
221                     typ: DnsType::A,
222                     class: DNSCLASS_INET,
223                     name: packed_name,
224                 }],
225                 ..Default::default()
226             };
227 
228             match msg.pack() {
229                 Ok(v) => v,
230                 Err(err) => {
231                     log::error!("Failed to construct mDNS packet {}", err);
232                     return;
233                 }
234             }
235         };
236 
237         log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
238         if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
239             log::error!("Failed to send mDNS packet {}", err);
240         }
241     }
242 
start( mut closed_rx: mpsc::Receiver<()>, close_server: Arc<atomic::AtomicBool>, socket: Arc<UdpSocket>, local_names: Vec<String>, dst_addr: SocketAddr, queries: Arc<Mutex<Vec<Query>>>, ) -> Result<()>243     async fn start(
244         mut closed_rx: mpsc::Receiver<()>,
245         close_server: Arc<atomic::AtomicBool>,
246         socket: Arc<UdpSocket>,
247         local_names: Vec<String>,
248         dst_addr: SocketAddr,
249         queries: Arc<Mutex<Vec<Query>>>,
250     ) -> Result<()> {
251         log::info!("Looping and listening {:?}", socket.local_addr());
252 
253         let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
254         let (mut n, mut src);
255 
256         loop {
257             tokio::select! {
258                 _ = closed_rx.recv() => {
259                     log::info!("Closing server connection");
260                     close_server.store(true, atomic::Ordering::SeqCst);
261 
262                     return Ok(());
263                 }
264 
265                 result = socket.recv_from(&mut b) => {
266                     match result{
267                         Ok((len, addr)) => {
268                             n = len;
269                             src = addr;
270                             log::trace!("Received new connection from {:?}", addr);
271                         },
272 
273                         Err(err) => {
274                             log::error!("Error receiving from socket connection: {:?}", err);
275                             continue;
276                         },
277                     }
278                 }
279             }
280 
281             let mut p = Parser::default();
282             if let Err(err) = p.start(&b[..n]) {
283                 log::error!("Failed to parse mDNS packet {}", err);
284                 continue;
285             }
286 
287             run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
288         }
289     }
290 }
291 
run( p: &mut Parser<'_>, socket: &Arc<UdpSocket>, local_names: &[String], src: SocketAddr, dst_addr: SocketAddr, queries: &Arc<Mutex<Vec<Query>>>, )292 async fn run(
293     p: &mut Parser<'_>,
294     socket: &Arc<UdpSocket>,
295     local_names: &[String],
296     src: SocketAddr,
297     dst_addr: SocketAddr,
298     queries: &Arc<Mutex<Vec<Query>>>,
299 ) {
300     let mut interface_addr = None;
301     for _ in 0..=MAX_MESSAGE_RECORDS {
302         let q = match p.question() {
303             Ok(q) => q,
304             Err(err) => {
305                 if Error::ErrSectionDone == err {
306                     log::trace!("Parsing has completed");
307                     break;
308                 } else {
309                     log::error!("Failed to parse mDNS packet {}", err);
310                     return;
311                 }
312             }
313         };
314 
315         for local_name in local_names {
316             if *local_name == q.name.data {
317                 let interface_addr = match interface_addr {
318                     Some(addr) => addr,
319                     None => match get_interface_addr_for_ip(src).await {
320                         Ok(addr) => {
321                             interface_addr.replace(addr);
322                             addr
323                         }
324                         Err(e) => {
325                             log::warn!(
326                                 "Failed to get local interface to communicate with {}: {:?}",
327                                 &src,
328                                 e
329                             );
330                             continue;
331                         }
332                     },
333                 };
334 
335                 log::trace!(
336                     "Found local name: {} to send answer, IP {}, interface addr {}",
337                     local_name,
338                     src.ip(),
339                     interface_addr
340                 );
341                 if let Err(e) =
342                     send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
343                 {
344                     log::error!("Error sending answer to client: {:?}", e);
345                     continue;
346                 };
347             }
348         }
349     }
350 
351     for _ in 0..=MAX_MESSAGE_RECORDS {
352         let a = match p.answer_header() {
353             Ok(a) => a,
354             Err(err) => {
355                 if Error::ErrSectionDone != err {
356                     log::warn!("Failed to parse mDNS packet {}", err);
357                 }
358                 return;
359             }
360         };
361 
362         if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
363             continue;
364         }
365 
366         let mut qs = queries.lock().await;
367         for j in (0..qs.len()).rev() {
368             if qs[j].name_with_suffix == a.name.data {
369                 let _ = qs[j]
370                     .query_result_chan
371                     .send(QueryResult {
372                         answer: a.clone(),
373                         addr: src,
374                     })
375                     .await;
376                 qs.remove(j);
377             }
378         }
379     }
380 }
381 
send_answer( socket: &Arc<UdpSocket>, interface_addr: &SocketAddr, name: &str, dst: IpAddr, dst_addr: SocketAddr, ) -> Result<()>382 async fn send_answer(
383     socket: &Arc<UdpSocket>,
384     interface_addr: &SocketAddr,
385     name: &str,
386     dst: IpAddr,
387     dst_addr: SocketAddr,
388 ) -> Result<()> {
389     let raw_answer = {
390         let mut msg = Message {
391             header: Header {
392                 response: true,
393                 authoritative: true,
394                 ..Default::default()
395             },
396 
397             answers: vec![Resource {
398                 header: ResourceHeader {
399                     typ: DnsType::A,
400                     class: DNSCLASS_INET,
401                     name: Name::new(name)?,
402                     ttl: RESPONSE_TTL,
403                     ..Default::default()
404                 },
405                 body: Some(Box::new(AResource {
406                     a: match interface_addr.ip() {
407                         IpAddr::V4(ip) => ip.octets(),
408                         IpAddr::V6(_) => {
409                             return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
410                         }
411                     },
412                 })),
413             }],
414             ..Default::default()
415         };
416 
417         msg.pack()?
418     };
419 
420     socket.send_to(&raw_answer, dst_addr).await?;
421     log::trace!("Sent answer to IP {}", dst);
422 
423     Ok(())
424 }
425 
get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr>426 async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
427     let socket = UdpSocket::bind("0.0.0.0:0").await?;
428     socket.connect(addr).await?;
429     socket.local_addr()
430 }
431