1 use crate::config::*;
2 use crate::error::*;
3 use crate::message::name::*;
4 use crate::message::{header::*, parser::*, question::*, resource::a::*, resource::*, *};
5
6 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7 use std::sync::Arc;
8 use std::time::Duration;
9
10 use core::sync::atomic;
11 use socket2::SockAddr;
12 use tokio::net::{ToSocketAddrs, UdpSocket};
13 use tokio::sync::mpsc;
14 use tokio::sync::Mutex;
15
16 use util::ifaces;
17
18 mod conn_test;
19
/// Default destination for outgoing mDNS questions: the IPv4 mDNS multicast
/// group (224.0.0.251) on the mDNS port.
pub const DEFAULT_DEST_ADDR: &str = "224.0.0.251:5353";

/// Size of the receive buffer for incoming datagrams.
///
/// RFC 6762 §17 allows mDNS packets of up to 9000 bytes (including IP/UDP
/// headers). A smaller buffer truncates larger datagrams on receive; those
/// then fail to parse and are dropped, losing any answers they carried.
const INBOUND_BUFFER_SIZE: usize = 9000;
/// Re-query interval used when `Config::query_interval` is zero.
const DEFAULT_QUERY_INTERVAL: Duration = Duration::from_secs(1);
/// Bounds how many question and answer records are examined per packet.
const MAX_MESSAGE_RECORDS: usize = 3;
/// TTL (seconds, per DNS TTL semantics) placed on the `A` records we send.
const RESPONSE_TTL: u32 = 120;
26
/// An mDNS connection, created with `DnsConn::server`.
///
/// It owns a UDP socket joined to the 224.0.0.251 multicast group. A
/// background task spawned by `server` answers questions for the configured
/// local names and delivers answers to pending `DnsConn::query` calls.
pub struct DnsConn {
    /// UDP socket, shared with the background receive task.
    socket: Arc<UdpSocket>,
    /// Destination for outgoing questions (parsed from `DEFAULT_DEST_ADDR`).
    dst_addr: SocketAddr,

    /// Delay between repeated questions while a query is outstanding.
    query_interval: Duration,
    /// Outstanding queries; the receive task delivers matching answers to them.
    queries: Arc<Mutex<Vec<Query>>>,

    /// Set to `true` by the receive task once it has handled a close request.
    is_server_closed: Arc<atomic::AtomicBool>,
    /// Used by `close` to signal the receive task to stop.
    close_server: mpsc::Sender<()>,
}
38
/// A pending lookup registered by `DnsConn::query`.
struct Query {
    /// The queried name with a trailing `.` appended; compared verbatim
    /// against the names of received answer records.
    name_with_suffix: String,
    /// Channel on which the receive task delivers the first matching answer;
    /// the entry is removed from the pending list after that send.
    query_result_chan: mpsc::Sender<QueryResult>,
}
43
/// An answer handed from the receive task to a waiting `DnsConn::query`.
struct QueryResult {
    /// Header of the matching `A` or `AAAA` answer record.
    answer: ResourceHeader,
    /// Source address of the packet that carried the answer.
    addr: SocketAddr,
}
48
49 impl DnsConn {
50 /// server establishes a mDNS connection over an existing connection
server(addr: SocketAddr, config: Config) -> Result<Self>51 pub fn server(addr: SocketAddr, config: Config) -> Result<Self> {
52 let socket = socket2::Socket::new(
53 socket2::Domain::IPV4,
54 socket2::Type::DGRAM,
55 Some(socket2::Protocol::UDP),
56 )?;
57
58 #[cfg(feature = "reuse_port")]
59 #[cfg(target_family = "unix")]
60 socket.set_reuse_port(true)?;
61
62 socket.set_reuse_address(true)?;
63 socket.set_broadcast(true)?;
64 socket.set_nonblocking(true)?;
65
66 socket.bind(&SockAddr::from(addr))?;
67 {
68 let mut join_error_count = 0;
69 let interfaces = match ifaces::ifaces() {
70 Ok(e) => e,
71 Err(e) => {
72 log::error!("Error getting interfaces: {:?}", e);
73 return Err(Error::Other(e.to_string()));
74 }
75 };
76
77 for interface in &interfaces {
78 if let Some(SocketAddr::V4(e)) = interface.addr {
79 if let Err(e) = socket.join_multicast_v4(&Ipv4Addr::new(224, 0, 0, 251), e.ip())
80 {
81 log::trace!("Error connecting multicast, error: {:?}", e);
82 join_error_count += 1;
83 continue;
84 }
85
86 log::trace!("Connected to interface address {:?}", e);
87 }
88 }
89
90 if join_error_count >= interfaces.len() {
91 return Err(Error::ErrJoiningMulticastGroup);
92 }
93 }
94
95 let socket = UdpSocket::from_std(socket.into())?;
96
97 let local_names = config
98 .local_names
99 .iter()
100 .map(|l| l.to_string() + ".")
101 .collect();
102
103 let dst_addr: SocketAddr = DEFAULT_DEST_ADDR.parse()?;
104
105 let is_server_closed = Arc::new(atomic::AtomicBool::new(false));
106
107 let (close_server_send, close_server_rcv) = mpsc::channel(1);
108
109 let c = DnsConn {
110 query_interval: if config.query_interval != Duration::from_secs(0) {
111 config.query_interval
112 } else {
113 DEFAULT_QUERY_INTERVAL
114 },
115
116 queries: Arc::new(Mutex::new(vec![])),
117 socket: Arc::new(socket),
118 dst_addr,
119 is_server_closed: Arc::clone(&is_server_closed),
120 close_server: close_server_send,
121 };
122
123 let queries = c.queries.clone();
124 let socket = Arc::clone(&c.socket);
125
126 tokio::spawn(async move {
127 DnsConn::start(
128 close_server_rcv,
129 is_server_closed,
130 socket,
131 local_names,
132 dst_addr,
133 queries,
134 )
135 .await
136 });
137
138 Ok(c)
139 }
140
141 /// Close closes the mDNS Conn
close(&self) -> Result<()>142 pub async fn close(&self) -> Result<()> {
143 log::info!("Closing connection");
144 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
145 return Err(Error::ErrConnectionClosed);
146 }
147
148 log::trace!("Sending close command to server");
149 match self.close_server.send(()).await {
150 Ok(_) => {
151 log::trace!("Close command sent");
152 Ok(())
153 }
154 Err(e) => {
155 log::warn!("Error sending close command to server: {:?}", e);
156 Err(Error::ErrConnectionClosed)
157 }
158 }
159 }
160
161 /// Query sends mDNS Queries for the following name until
162 /// either there's a close signal or we get a result
query( &self, name: &str, mut close_query_signal: mpsc::Receiver<()>, ) -> Result<(ResourceHeader, SocketAddr)>163 pub async fn query(
164 &self,
165 name: &str,
166 mut close_query_signal: mpsc::Receiver<()>,
167 ) -> Result<(ResourceHeader, SocketAddr)> {
168 if self.is_server_closed.load(atomic::Ordering::SeqCst) {
169 return Err(Error::ErrConnectionClosed);
170 }
171
172 let name_with_suffix = name.to_owned() + ".";
173
174 let (query_tx, mut query_rx) = mpsc::channel(1);
175 {
176 let mut queries = self.queries.lock().await;
177 queries.push(Query {
178 name_with_suffix: name_with_suffix.clone(),
179 query_result_chan: query_tx,
180 });
181 }
182
183 log::trace!("Sending query");
184 self.send_question(&name_with_suffix).await;
185
186 loop {
187 tokio::select! {
188 _ = tokio::time::sleep(self.query_interval) => {
189 log::trace!("Sending query");
190 self.send_question(&name_with_suffix).await
191 },
192
193 _ = close_query_signal.recv() => {
194 log::info!("Query close signal received.");
195 return Err(Error::ErrConnectionClosed)
196 },
197
198 res_opt = query_rx.recv() =>{
199 log::info!("Received query result");
200 if let Some(res) = res_opt{
201 return Ok((res.answer, res.addr));
202 }
203 }
204 }
205 }
206 }
207
send_question(&self, name: &str)208 async fn send_question(&self, name: &str) {
209 let packed_name = match Name::new(name) {
210 Ok(pn) => pn,
211 Err(err) => {
212 log::warn!("Failed to construct mDNS packet: {}", err);
213 return;
214 }
215 };
216
217 let raw_query = {
218 let mut msg = Message {
219 header: Header::default(),
220 questions: vec![Question {
221 typ: DnsType::A,
222 class: DNSCLASS_INET,
223 name: packed_name,
224 }],
225 ..Default::default()
226 };
227
228 match msg.pack() {
229 Ok(v) => v,
230 Err(err) => {
231 log::error!("Failed to construct mDNS packet {}", err);
232 return;
233 }
234 }
235 };
236
237 log::trace!("{:?} sending {:?}...", self.socket.local_addr(), raw_query);
238 if let Err(err) = self.socket.send_to(&raw_query, self.dst_addr).await {
239 log::error!("Failed to send mDNS packet {}", err);
240 }
241 }
242
start( mut closed_rx: mpsc::Receiver<()>, close_server: Arc<atomic::AtomicBool>, socket: Arc<UdpSocket>, local_names: Vec<String>, dst_addr: SocketAddr, queries: Arc<Mutex<Vec<Query>>>, ) -> Result<()>243 async fn start(
244 mut closed_rx: mpsc::Receiver<()>,
245 close_server: Arc<atomic::AtomicBool>,
246 socket: Arc<UdpSocket>,
247 local_names: Vec<String>,
248 dst_addr: SocketAddr,
249 queries: Arc<Mutex<Vec<Query>>>,
250 ) -> Result<()> {
251 log::info!("Looping and listening {:?}", socket.local_addr());
252
253 let mut b = vec![0u8; INBOUND_BUFFER_SIZE];
254 let (mut n, mut src);
255
256 loop {
257 tokio::select! {
258 _ = closed_rx.recv() => {
259 log::info!("Closing server connection");
260 close_server.store(true, atomic::Ordering::SeqCst);
261
262 return Ok(());
263 }
264
265 result = socket.recv_from(&mut b) => {
266 match result{
267 Ok((len, addr)) => {
268 n = len;
269 src = addr;
270 log::trace!("Received new connection from {:?}", addr);
271 },
272
273 Err(err) => {
274 log::error!("Error receiving from socket connection: {:?}", err);
275 continue;
276 },
277 }
278 }
279 }
280
281 let mut p = Parser::default();
282 if let Err(err) = p.start(&b[..n]) {
283 log::error!("Failed to parse mDNS packet {}", err);
284 continue;
285 }
286
287 run(&mut p, &socket, &local_names, src, dst_addr, &queries).await
288 }
289 }
290 }
291
run( p: &mut Parser<'_>, socket: &Arc<UdpSocket>, local_names: &[String], src: SocketAddr, dst_addr: SocketAddr, queries: &Arc<Mutex<Vec<Query>>>, )292 async fn run(
293 p: &mut Parser<'_>,
294 socket: &Arc<UdpSocket>,
295 local_names: &[String],
296 src: SocketAddr,
297 dst_addr: SocketAddr,
298 queries: &Arc<Mutex<Vec<Query>>>,
299 ) {
300 let mut interface_addr = None;
301 for _ in 0..=MAX_MESSAGE_RECORDS {
302 let q = match p.question() {
303 Ok(q) => q,
304 Err(err) => {
305 if Error::ErrSectionDone == err {
306 log::trace!("Parsing has completed");
307 break;
308 } else {
309 log::error!("Failed to parse mDNS packet {}", err);
310 return;
311 }
312 }
313 };
314
315 for local_name in local_names {
316 if *local_name == q.name.data {
317 let interface_addr = match interface_addr {
318 Some(addr) => addr,
319 None => match get_interface_addr_for_ip(src).await {
320 Ok(addr) => {
321 interface_addr.replace(addr);
322 addr
323 }
324 Err(e) => {
325 log::warn!(
326 "Failed to get local interface to communicate with {}: {:?}",
327 &src,
328 e
329 );
330 continue;
331 }
332 },
333 };
334
335 log::trace!(
336 "Found local name: {} to send answer, IP {}, interface addr {}",
337 local_name,
338 src.ip(),
339 interface_addr
340 );
341 if let Err(e) =
342 send_answer(socket, &interface_addr, &q.name.data, src.ip(), dst_addr).await
343 {
344 log::error!("Error sending answer to client: {:?}", e);
345 continue;
346 };
347 }
348 }
349 }
350
351 for _ in 0..=MAX_MESSAGE_RECORDS {
352 let a = match p.answer_header() {
353 Ok(a) => a,
354 Err(err) => {
355 if Error::ErrSectionDone != err {
356 log::warn!("Failed to parse mDNS packet {}", err);
357 }
358 return;
359 }
360 };
361
362 if a.typ != DnsType::A && a.typ != DnsType::Aaaa {
363 continue;
364 }
365
366 let mut qs = queries.lock().await;
367 for j in (0..qs.len()).rev() {
368 if qs[j].name_with_suffix == a.name.data {
369 let _ = qs[j]
370 .query_result_chan
371 .send(QueryResult {
372 answer: a.clone(),
373 addr: src,
374 })
375 .await;
376 qs.remove(j);
377 }
378 }
379 }
380 }
381
send_answer( socket: &Arc<UdpSocket>, interface_addr: &SocketAddr, name: &str, dst: IpAddr, dst_addr: SocketAddr, ) -> Result<()>382 async fn send_answer(
383 socket: &Arc<UdpSocket>,
384 interface_addr: &SocketAddr,
385 name: &str,
386 dst: IpAddr,
387 dst_addr: SocketAddr,
388 ) -> Result<()> {
389 let raw_answer = {
390 let mut msg = Message {
391 header: Header {
392 response: true,
393 authoritative: true,
394 ..Default::default()
395 },
396
397 answers: vec![Resource {
398 header: ResourceHeader {
399 typ: DnsType::A,
400 class: DNSCLASS_INET,
401 name: Name::new(name)?,
402 ttl: RESPONSE_TTL,
403 ..Default::default()
404 },
405 body: Some(Box::new(AResource {
406 a: match interface_addr.ip() {
407 IpAddr::V4(ip) => ip.octets(),
408 IpAddr::V6(_) => {
409 return Err(Error::Other("Unexpected IpV6 addr".to_owned()))
410 }
411 },
412 })),
413 }],
414 ..Default::default()
415 };
416
417 msg.pack()?
418 };
419
420 socket.send_to(&raw_answer, dst_addr).await?;
421 log::trace!("Sent answer to IP {}", dst);
422
423 Ok(())
424 }
425
get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr>426 async fn get_interface_addr_for_ip(addr: impl ToSocketAddrs) -> std::io::Result<SocketAddr> {
427 let socket = UdpSocket::bind("0.0.0.0:0").await?;
428 socket.connect(addr).await?;
429 socket.local_addr()
430 }
431