xref: /webrtc/ice/src/agent/agent_internal.rs (revision 5b79f08a)
1 use super::agent_transport::*;
2 use super::*;
3 use crate::candidate::candidate_base::CandidateBaseConfig;
4 use crate::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig;
5 use crate::util::*;
6 use arc_swap::ArcSwapOption;
7 use std::sync::atomic::{AtomicBool, AtomicU64};
8 use util::sync::Mutex as SyncMutex;
9 
10 pub type ChanCandidateTx =
11     Arc<Mutex<Option<mpsc::Sender<Option<Arc<dyn Candidate + Send + Sync>>>>>>;
12 
/// ICE username fragments and passwords for both ends of the session.
///
/// `Default` yields four empty strings.
#[derive(Default)]
pub(crate) struct UfragPwd {
    /// Our own username fragment.
    pub(crate) local_ufrag: String,
    /// Our own password (used for short-term integrity on responses we send).
    pub(crate) local_pwd: String,
    /// The peer's username fragment.
    pub(crate) remote_ufrag: String,
    /// The peer's password.
    pub(crate) remote_pwd: String,
}
20 
21 pub struct AgentInternal {
22     // State owned by the taskLoop
23     pub(crate) on_connected_tx: Mutex<Option<mpsc::Sender<()>>>,
24     pub(crate) on_connected_rx: Mutex<Option<mpsc::Receiver<()>>>,
25 
26     // State for closing
27     pub(crate) done_tx: Mutex<Option<mpsc::Sender<()>>>,
28     // force candidate to be contacted immediately (instead of waiting for task ticker)
29     pub(crate) force_candidate_contact_tx: mpsc::Sender<bool>,
30     pub(crate) done_and_force_candidate_contact_rx:
31         Mutex<Option<(mpsc::Receiver<()>, mpsc::Receiver<bool>)>>,
32 
33     pub(crate) chan_candidate_tx: ChanCandidateTx,
34     pub(crate) chan_candidate_pair_tx: Mutex<Option<mpsc::Sender<()>>>,
35     pub(crate) chan_state_tx: Mutex<Option<mpsc::Sender<ConnectionState>>>,
36 
37     pub(crate) on_connection_state_change_hdlr: ArcSwapOption<Mutex<OnConnectionStateChangeHdlrFn>>,
38     pub(crate) on_selected_candidate_pair_change_hdlr:
39         ArcSwapOption<Mutex<OnSelectedCandidatePairChangeHdlrFn>>,
40     pub(crate) on_candidate_hdlr: ArcSwapOption<Mutex<OnCandidateHdlrFn>>,
41 
42     pub(crate) tie_breaker: AtomicU64,
43     pub(crate) is_controlling: AtomicBool,
44     pub(crate) lite: AtomicBool,
45 
46     pub(crate) start_time: SyncMutex<Instant>,
47     pub(crate) nominated_pair: Mutex<Option<Arc<CandidatePair>>>,
48 
49     pub(crate) connection_state: AtomicU8, //ConnectionState,
50 
51     pub(crate) started_ch_tx: Mutex<Option<broadcast::Sender<()>>>,
52 
53     pub(crate) ufrag_pwd: Mutex<UfragPwd>,
54 
55     pub(crate) local_candidates: Mutex<HashMap<NetworkType, Vec<Arc<dyn Candidate + Send + Sync>>>>,
56     pub(crate) remote_candidates:
57         Mutex<HashMap<NetworkType, Vec<Arc<dyn Candidate + Send + Sync>>>>,
58 
59     // LRU of outbound Binding request Transaction IDs
60     pub(crate) pending_binding_requests: Mutex<Vec<BindingRequest>>,
61 
62     pub(crate) agent_conn: Arc<AgentConn>,
63 
64     // the following variables won't be changed after init_with_defaults()
65     pub(crate) insecure_skip_verify: bool,
66     pub(crate) max_binding_requests: u16,
67     pub(crate) host_acceptance_min_wait: Duration,
68     pub(crate) srflx_acceptance_min_wait: Duration,
69     pub(crate) prflx_acceptance_min_wait: Duration,
70     pub(crate) relay_acceptance_min_wait: Duration,
71     // How long connectivity checks can fail before the ICE Agent
72     // goes to disconnected
73     pub(crate) disconnected_timeout: Duration,
74     // How long connectivity checks can fail before the ICE Agent
75     // goes to failed
76     pub(crate) failed_timeout: Duration,
77     // How often should we send keepalive packets?
78     // 0 means never
79     pub(crate) keepalive_interval: Duration,
80     // How often should we run our internal taskLoop to check for state changes when connecting
81     pub(crate) check_interval: Duration,
82 }
83 
84 impl AgentInternal {
new(config: &AgentConfig) -> (Self, ChanReceivers)85     pub(super) fn new(config: &AgentConfig) -> (Self, ChanReceivers) {
86         let (chan_state_tx, chan_state_rx) = mpsc::channel(1);
87         let (chan_candidate_tx, chan_candidate_rx) = mpsc::channel(1);
88         let (chan_candidate_pair_tx, chan_candidate_pair_rx) = mpsc::channel(1);
89         let (on_connected_tx, on_connected_rx) = mpsc::channel(1);
90         let (done_tx, done_rx) = mpsc::channel(1);
91         let (force_candidate_contact_tx, force_candidate_contact_rx) = mpsc::channel(1);
92         let (started_ch_tx, _) = broadcast::channel(1);
93 
94         let ai = AgentInternal {
95             on_connected_tx: Mutex::new(Some(on_connected_tx)),
96             on_connected_rx: Mutex::new(Some(on_connected_rx)),
97 
98             done_tx: Mutex::new(Some(done_tx)),
99             force_candidate_contact_tx,
100             done_and_force_candidate_contact_rx: Mutex::new(Some((
101                 done_rx,
102                 force_candidate_contact_rx,
103             ))),
104 
105             chan_candidate_tx: Arc::new(Mutex::new(Some(chan_candidate_tx))),
106             chan_candidate_pair_tx: Mutex::new(Some(chan_candidate_pair_tx)),
107             chan_state_tx: Mutex::new(Some(chan_state_tx)),
108 
109             on_connection_state_change_hdlr: ArcSwapOption::empty(),
110             on_selected_candidate_pair_change_hdlr: ArcSwapOption::empty(),
111             on_candidate_hdlr: ArcSwapOption::empty(),
112 
113             tie_breaker: AtomicU64::new(rand::random::<u64>()),
114             is_controlling: AtomicBool::new(config.is_controlling),
115             lite: AtomicBool::new(config.lite),
116 
117             start_time: SyncMutex::new(Instant::now()),
118             nominated_pair: Mutex::new(None),
119 
120             connection_state: AtomicU8::new(ConnectionState::New as u8),
121 
122             insecure_skip_verify: config.insecure_skip_verify,
123 
124             started_ch_tx: Mutex::new(Some(started_ch_tx)),
125 
126             //won't change after init_with_defaults()
127             max_binding_requests: 0,
128             host_acceptance_min_wait: Duration::from_secs(0),
129             srflx_acceptance_min_wait: Duration::from_secs(0),
130             prflx_acceptance_min_wait: Duration::from_secs(0),
131             relay_acceptance_min_wait: Duration::from_secs(0),
132 
133             // How long connectivity checks can fail before the ICE Agent
134             // goes to disconnected
135             disconnected_timeout: Duration::from_secs(0),
136 
137             // How long connectivity checks can fail before the ICE Agent
138             // goes to failed
139             failed_timeout: Duration::from_secs(0),
140 
141             // How often should we send keepalive packets?
142             // 0 means never
143             keepalive_interval: Duration::from_secs(0),
144 
145             // How often should we run our internal taskLoop to check for state changes when connecting
146             check_interval: Duration::from_secs(0),
147 
148             ufrag_pwd: Mutex::new(UfragPwd::default()),
149 
150             local_candidates: Mutex::new(HashMap::new()),
151             remote_candidates: Mutex::new(HashMap::new()),
152 
153             // LRU of outbound Binding request Transaction IDs
154             pending_binding_requests: Mutex::new(vec![]),
155 
156             // AgentConn
157             agent_conn: Arc::new(AgentConn::new()),
158         };
159 
160         let chan_receivers = ChanReceivers {
161             chan_state_rx,
162             chan_candidate_rx,
163             chan_candidate_pair_rx,
164         };
165         (ai, chan_receivers)
166     }
start_connectivity_checks( self: &Arc<Self>, is_controlling: bool, remote_ufrag: String, remote_pwd: String, ) -> Result<()>167     pub(crate) async fn start_connectivity_checks(
168         self: &Arc<Self>,
169         is_controlling: bool,
170         remote_ufrag: String,
171         remote_pwd: String,
172     ) -> Result<()> {
173         {
174             let started_ch_tx = self.started_ch_tx.lock().await;
175             if started_ch_tx.is_none() {
176                 return Err(Error::ErrMultipleStart);
177             }
178         }
179 
180         log::debug!(
181             "Started agent: isControlling? {}, remoteUfrag: {}, remotePwd: {}",
182             is_controlling,
183             remote_ufrag,
184             remote_pwd
185         );
186         self.set_remote_credentials(remote_ufrag, remote_pwd)
187             .await?;
188         self.is_controlling.store(is_controlling, Ordering::SeqCst);
189         self.start().await;
190         {
191             let mut started_ch_tx = self.started_ch_tx.lock().await;
192             started_ch_tx.take();
193         }
194 
195         self.update_connection_state(ConnectionState::Checking)
196             .await;
197 
198         self.request_connectivity_check();
199 
200         self.connectivity_checks().await;
201 
202         Ok(())
203     }
204 
contact( &self, last_connection_state: &mut ConnectionState, checking_duration: &mut Instant, )205     async fn contact(
206         &self,
207         last_connection_state: &mut ConnectionState,
208         checking_duration: &mut Instant,
209     ) {
210         if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Failed as u8 {
211             // The connection is currently failed so don't send any checks
212             // In the future it may be restarted though
213             *last_connection_state = self.connection_state.load(Ordering::SeqCst).into();
214             return;
215         }
216         if self.connection_state.load(Ordering::SeqCst) == ConnectionState::Checking as u8 {
217             // We have just entered checking for the first time so update our checking timer
218             if *last_connection_state as u8 != self.connection_state.load(Ordering::SeqCst) {
219                 *checking_duration = Instant::now();
220             }
221 
222             // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
223             if Instant::now()
224                 .checked_duration_since(*checking_duration)
225                 .unwrap_or_else(|| Duration::from_secs(0))
226                 > self.disconnected_timeout + self.failed_timeout
227             {
228                 self.update_connection_state(ConnectionState::Failed).await;
229                 *last_connection_state = self.connection_state.load(Ordering::SeqCst).into();
230                 return;
231             }
232         }
233 
234         self.contact_candidates().await;
235 
236         *last_connection_state = self.connection_state.load(Ordering::SeqCst).into();
237     }
238 
connectivity_checks(self: &Arc<Self>)239     async fn connectivity_checks(self: &Arc<Self>) {
240         const ZERO_DURATION: Duration = Duration::from_secs(0);
241         let mut last_connection_state = ConnectionState::Unspecified;
242         let mut checking_duration = Instant::now();
243         let (check_interval, keepalive_interval, disconnected_timeout, failed_timeout) = (
244             self.check_interval,
245             self.keepalive_interval,
246             self.disconnected_timeout,
247             self.failed_timeout,
248         );
249 
250         let done_and_force_candidate_contact_rx = {
251             let mut done_and_force_candidate_contact_rx =
252                 self.done_and_force_candidate_contact_rx.lock().await;
253             done_and_force_candidate_contact_rx.take()
254         };
255 
256         if let Some((mut done_rx, mut force_candidate_contact_rx)) =
257             done_and_force_candidate_contact_rx
258         {
259             let ai = Arc::clone(self);
260             tokio::spawn(async move {
261                 loop {
262                     let mut interval = DEFAULT_CHECK_INTERVAL;
263 
264                     let mut update_interval = |x: Duration| {
265                         if x != ZERO_DURATION && (interval == ZERO_DURATION || interval > x) {
266                             interval = x;
267                         }
268                     };
269 
270                     match last_connection_state {
271                         ConnectionState::New | ConnectionState::Checking => {
272                             // While connecting, check candidates more frequently
273                             update_interval(check_interval);
274                         }
275                         ConnectionState::Connected | ConnectionState::Disconnected => {
276                             update_interval(keepalive_interval);
277                         }
278                         _ => {}
279                     };
280                     // Ensure we run our task loop as quickly as the minimum of our various configured timeouts
281                     update_interval(disconnected_timeout);
282                     update_interval(failed_timeout);
283 
284                     let t = tokio::time::sleep(interval);
285                     tokio::pin!(t);
286 
287                     tokio::select! {
288                         _ = t.as_mut() => {
289                             ai.contact(&mut last_connection_state, &mut checking_duration).await;
290                         },
291                         _ = force_candidate_contact_rx.recv() => {
292                             ai.contact(&mut last_connection_state, &mut checking_duration).await;
293                         },
294                         _ = done_rx.recv() => {
295                             return;
296                         }
297                     }
298                 }
299             });
300         }
301     }
302 
update_connection_state(&self, new_state: ConnectionState)303     pub(crate) async fn update_connection_state(&self, new_state: ConnectionState) {
304         if self.connection_state.load(Ordering::SeqCst) != new_state as u8 {
305             // Connection has gone to failed, release all gathered candidates
306             if new_state == ConnectionState::Failed {
307                 self.delete_all_candidates().await;
308             }
309 
310             log::info!(
311                 "[{}]: Setting new connection state: {}",
312                 self.get_name(),
313                 new_state
314             );
315             self.connection_state
316                 .store(new_state as u8, Ordering::SeqCst);
317 
318             // Call handler after finishing current task since we may be holding the agent lock
319             // and the handler may also require it
320             {
321                 let chan_state_tx = self.chan_state_tx.lock().await;
322                 if let Some(tx) = &*chan_state_tx {
323                     let _ = tx.send(new_state).await;
324                 }
325             }
326         }
327     }
328 
set_selected_pair(&self, p: Option<Arc<CandidatePair>>)329     pub(crate) async fn set_selected_pair(&self, p: Option<Arc<CandidatePair>>) {
330         log::trace!(
331             "[{}]: Set selected candidate pair: {:?}",
332             self.get_name(),
333             p
334         );
335 
336         if let Some(p) = p {
337             p.nominated.store(true, Ordering::SeqCst);
338             self.agent_conn.selected_pair.store(Some(p));
339 
340             self.update_connection_state(ConnectionState::Connected)
341                 .await;
342 
343             // Notify when the selected pair changes
344             {
345                 let chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await;
346                 if let Some(tx) = &*chan_candidate_pair_tx {
347                     let _ = tx.send(()).await;
348                 }
349             }
350 
351             // Signal connected
352             {
353                 let mut on_connected_tx = self.on_connected_tx.lock().await;
354                 on_connected_tx.take();
355             }
356         } else {
357             self.agent_conn.selected_pair.store(None);
358         }
359     }
360 
ping_all_candidates(&self)361     pub(crate) async fn ping_all_candidates(&self) {
362         log::trace!("[{}]: pinging all candidates", self.get_name(),);
363 
364         let mut pairs: Vec<(
365             Arc<dyn Candidate + Send + Sync>,
366             Arc<dyn Candidate + Send + Sync>,
367         )> = vec![];
368 
369         {
370             let mut checklist = self.agent_conn.checklist.lock().await;
371             if checklist.is_empty() {
372                 log::warn!(
373                     "[{}]: pingAllCandidates called with no candidate pairs. Connection is not possible yet.",
374                     self.get_name(),
375                 );
376             }
377             for p in &mut *checklist {
378                 let p_state = p.state.load(Ordering::SeqCst);
379                 if p_state == CandidatePairState::Waiting as u8 {
380                     p.state
381                         .store(CandidatePairState::InProgress as u8, Ordering::SeqCst);
382                 } else if p_state != CandidatePairState::InProgress as u8 {
383                     continue;
384                 }
385 
386                 if p.binding_request_count.load(Ordering::SeqCst) > self.max_binding_requests {
387                     log::trace!(
388                         "[{}]: max requests reached for pair {}, marking it as failed",
389                         self.get_name(),
390                         p
391                     );
392                     p.state
393                         .store(CandidatePairState::Failed as u8, Ordering::SeqCst);
394                 } else {
395                     p.binding_request_count.fetch_add(1, Ordering::SeqCst);
396                     let local = p.local.clone();
397                     let remote = p.remote.clone();
398                     pairs.push((local, remote));
399                 }
400             }
401         }
402 
403         for (local, remote) in pairs {
404             self.ping_candidate(&local, &remote).await;
405         }
406     }
407 
add_pair( &self, local: Arc<dyn Candidate + Send + Sync>, remote: Arc<dyn Candidate + Send + Sync>, )408     pub(crate) async fn add_pair(
409         &self,
410         local: Arc<dyn Candidate + Send + Sync>,
411         remote: Arc<dyn Candidate + Send + Sync>,
412     ) {
413         let p = Arc::new(CandidatePair::new(
414             local,
415             remote,
416             self.is_controlling.load(Ordering::SeqCst),
417         ));
418         let mut checklist = self.agent_conn.checklist.lock().await;
419         checklist.push(p);
420     }
421 
find_pair( &self, local: &Arc<dyn Candidate + Send + Sync>, remote: &Arc<dyn Candidate + Send + Sync>, ) -> Option<Arc<CandidatePair>>422     pub(crate) async fn find_pair(
423         &self,
424         local: &Arc<dyn Candidate + Send + Sync>,
425         remote: &Arc<dyn Candidate + Send + Sync>,
426     ) -> Option<Arc<CandidatePair>> {
427         let checklist = self.agent_conn.checklist.lock().await;
428         for p in &*checklist {
429             if p.local.equal(&**local) && p.remote.equal(&**remote) {
430                 return Some(p.clone());
431             }
432         }
433         None
434     }
435 
436     /// Checks if the selected pair is (still) valid.
437     /// Note: the caller should hold the agent lock.
validate_selected_pair(&self) -> bool438     pub(crate) async fn validate_selected_pair(&self) -> bool {
439         let (valid, disconnected_time) = {
440             let selected_pair = self.agent_conn.selected_pair.load();
441             (*selected_pair).as_ref().map_or_else(
442                 || (false, Duration::from_secs(0)),
443                 |selected_pair| {
444                     let disconnected_time = SystemTime::now()
445                         .duration_since(selected_pair.remote.last_received())
446                         .unwrap_or_else(|_| Duration::from_secs(0));
447                     (true, disconnected_time)
448                 },
449             )
450         };
451 
452         if valid {
453             // Only allow transitions to failed if a.failedTimeout is non-zero
454             let mut total_time_to_failure = self.failed_timeout;
455             if total_time_to_failure != Duration::from_secs(0) {
456                 total_time_to_failure += self.disconnected_timeout;
457             }
458 
459             if total_time_to_failure != Duration::from_secs(0)
460                 && disconnected_time > total_time_to_failure
461             {
462                 self.update_connection_state(ConnectionState::Failed).await;
463             } else if self.disconnected_timeout != Duration::from_secs(0)
464                 && disconnected_time > self.disconnected_timeout
465             {
466                 self.update_connection_state(ConnectionState::Disconnected)
467                     .await;
468             } else {
469                 self.update_connection_state(ConnectionState::Connected)
470                     .await;
471             }
472         }
473 
474         valid
475     }
476 
477     /// Sends STUN Binding Indications to the selected pair.
478     /// if no packet has been sent on that pair in the last keepaliveInterval.
479     /// Note: the caller should hold the agent lock.
check_keepalive(&self)480     pub(crate) async fn check_keepalive(&self) {
481         let (local, remote) = {
482             let selected_pair = self.agent_conn.selected_pair.load();
483             (*selected_pair)
484                 .as_ref()
485                 .map_or((None, None), |selected_pair| {
486                     (
487                         Some(selected_pair.local.clone()),
488                         Some(selected_pair.remote.clone()),
489                     )
490                 })
491         };
492 
493         if let (Some(local), Some(remote)) = (local, remote) {
494             let last_sent = SystemTime::now()
495                 .duration_since(local.last_sent())
496                 .unwrap_or_else(|_| Duration::from_secs(0));
497 
498             let last_received = SystemTime::now()
499                 .duration_since(remote.last_received())
500                 .unwrap_or_else(|_| Duration::from_secs(0));
501 
502             if (self.keepalive_interval != Duration::from_secs(0))
503                 && ((last_sent > self.keepalive_interval)
504                     || (last_received > self.keepalive_interval))
505             {
506                 // we use binding request instead of indication to support refresh consent schemas
507                 // see https://tools.ietf.org/html/rfc7675
508                 self.ping_candidate(&local, &remote).await;
509             }
510         }
511     }
512 
request_connectivity_check(&self)513     fn request_connectivity_check(&self) {
514         let _ = self.force_candidate_contact_tx.try_send(true);
515     }
516 
517     /// Assumes you are holding the lock (must be execute using a.run).
add_remote_candidate(&self, c: &Arc<dyn Candidate + Send + Sync>)518     pub(crate) async fn add_remote_candidate(&self, c: &Arc<dyn Candidate + Send + Sync>) {
519         let network_type = c.network_type();
520 
521         {
522             let mut remote_candidates = self.remote_candidates.lock().await;
523             if let Some(cands) = remote_candidates.get(&network_type) {
524                 for cand in cands {
525                     if cand.equal(&**c) {
526                         return;
527                     }
528                 }
529             }
530 
531             if let Some(cands) = remote_candidates.get_mut(&network_type) {
532                 cands.push(c.clone());
533             } else {
534                 remote_candidates.insert(network_type, vec![c.clone()]);
535             }
536         }
537 
538         let mut local_cands = vec![];
539         {
540             let local_candidates = self.local_candidates.lock().await;
541             if let Some(cands) = local_candidates.get(&network_type) {
542                 local_cands = cands.clone();
543             }
544         }
545 
546         for cand in local_cands {
547             self.add_pair(cand, c.clone()).await;
548         }
549 
550         self.request_connectivity_check();
551     }
552 
add_candidate( self: &Arc<Self>, c: &Arc<dyn Candidate + Send + Sync>, ) -> Result<()>553     pub(crate) async fn add_candidate(
554         self: &Arc<Self>,
555         c: &Arc<dyn Candidate + Send + Sync>,
556     ) -> Result<()> {
557         let initialized_ch = {
558             let started_ch_tx = self.started_ch_tx.lock().await;
559             (*started_ch_tx).as_ref().map(|tx| tx.subscribe())
560         };
561 
562         self.start_candidate(c, initialized_ch).await;
563 
564         let network_type = c.network_type();
565         {
566             let mut local_candidates = self.local_candidates.lock().await;
567             if let Some(cands) = local_candidates.get(&network_type) {
568                 for cand in cands {
569                     if cand.equal(&**c) {
570                         if let Err(err) = c.close().await {
571                             log::warn!(
572                                 "[{}]: Failed to close duplicate candidate: {}",
573                                 self.get_name(),
574                                 err
575                             );
576                         }
577                         //TODO: why return?
578                         return Ok(());
579                     }
580                 }
581             }
582 
583             if let Some(cands) = local_candidates.get_mut(&network_type) {
584                 cands.push(c.clone());
585             } else {
586                 local_candidates.insert(network_type, vec![c.clone()]);
587             }
588         }
589 
590         let mut remote_cands = vec![];
591         {
592             let remote_candidates = self.remote_candidates.lock().await;
593             if let Some(cands) = remote_candidates.get(&network_type) {
594                 remote_cands = cands.clone();
595             }
596         }
597 
598         for cand in remote_cands {
599             self.add_pair(c.clone(), cand).await;
600         }
601 
602         self.request_connectivity_check();
603         {
604             let chan_candidate_tx = self.chan_candidate_tx.lock().await;
605             if let Some(tx) = &*chan_candidate_tx {
606                 let _ = tx.send(Some(c.clone())).await;
607             }
608         }
609 
610         Ok(())
611     }
612 
close(&self) -> Result<()>613     pub(crate) async fn close(&self) -> Result<()> {
614         {
615             let mut done_tx = self.done_tx.lock().await;
616             if done_tx.is_none() {
617                 return Err(Error::ErrClosed);
618             }
619             done_tx.take();
620         };
621         self.delete_all_candidates().await;
622         {
623             let mut started_ch_tx = self.started_ch_tx.lock().await;
624             started_ch_tx.take();
625         }
626 
627         self.agent_conn.buffer.close().await;
628 
629         self.update_connection_state(ConnectionState::Closed).await;
630 
631         {
632             let mut chan_candidate_tx = self.chan_candidate_tx.lock().await;
633             chan_candidate_tx.take();
634         }
635         {
636             let mut chan_candidate_pair_tx = self.chan_candidate_pair_tx.lock().await;
637             chan_candidate_pair_tx.take();
638         }
639         {
640             let mut chan_state_tx = self.chan_state_tx.lock().await;
641             chan_state_tx.take();
642         }
643 
644         self.agent_conn.done.store(true, Ordering::SeqCst);
645 
646         Ok(())
647     }
648 
649     /// Remove all candidates.
650     /// This closes any listening sockets and removes both the local and remote candidate lists.
651     ///
652     /// This is used for restarts, failures and on close.
delete_all_candidates(&self)653     pub(crate) async fn delete_all_candidates(&self) {
654         {
655             let mut local_candidates = self.local_candidates.lock().await;
656             for cs in local_candidates.values_mut() {
657                 for c in cs {
658                     if let Err(err) = c.close().await {
659                         log::warn!(
660                             "[{}]: Failed to close candidate {}: {}",
661                             self.get_name(),
662                             c,
663                             err
664                         );
665                     }
666                 }
667             }
668             local_candidates.clear();
669         }
670 
671         {
672             let mut remote_candidates = self.remote_candidates.lock().await;
673             for cs in remote_candidates.values_mut() {
674                 for c in cs {
675                     if let Err(err) = c.close().await {
676                         log::warn!(
677                             "[{}]: Failed to close candidate {}: {}",
678                             self.get_name(),
679                             c,
680                             err
681                         );
682                     }
683                 }
684             }
685             remote_candidates.clear();
686         }
687     }
688 
find_remote_candidate( &self, network_type: NetworkType, addr: SocketAddr, ) -> Option<Arc<dyn Candidate + Send + Sync>>689     pub(crate) async fn find_remote_candidate(
690         &self,
691         network_type: NetworkType,
692         addr: SocketAddr,
693     ) -> Option<Arc<dyn Candidate + Send + Sync>> {
694         let (ip, port) = (addr.ip(), addr.port());
695 
696         let remote_candidates = self.remote_candidates.lock().await;
697         if let Some(cands) = remote_candidates.get(&network_type) {
698             for c in cands {
699                 if c.address() == ip.to_string() && c.port() == port {
700                     return Some(c.clone());
701                 }
702             }
703         }
704         None
705     }
706 
send_binding_request( &self, m: &Message, local: &Arc<dyn Candidate + Send + Sync>, remote: &Arc<dyn Candidate + Send + Sync>, )707     pub(crate) async fn send_binding_request(
708         &self,
709         m: &Message,
710         local: &Arc<dyn Candidate + Send + Sync>,
711         remote: &Arc<dyn Candidate + Send + Sync>,
712     ) {
713         log::trace!(
714             "[{}]: ping STUN from {} to {}",
715             self.get_name(),
716             local,
717             remote
718         );
719 
720         self.invalidate_pending_binding_requests(Instant::now())
721             .await;
722         {
723             let mut pending_binding_requests = self.pending_binding_requests.lock().await;
724             pending_binding_requests.push(BindingRequest {
725                 timestamp: Instant::now(),
726                 transaction_id: m.transaction_id,
727                 destination: remote.addr(),
728                 is_use_candidate: m.contains(ATTR_USE_CANDIDATE),
729             });
730         }
731 
732         self.send_stun(m, local, remote).await;
733     }
734 
send_binding_success( &self, m: &Message, local: &Arc<dyn Candidate + Send + Sync>, remote: &Arc<dyn Candidate + Send + Sync>, )735     pub(crate) async fn send_binding_success(
736         &self,
737         m: &Message,
738         local: &Arc<dyn Candidate + Send + Sync>,
739         remote: &Arc<dyn Candidate + Send + Sync>,
740     ) {
741         let addr = remote.addr();
742         let (ip, port) = (addr.ip(), addr.port());
743         let local_pwd = {
744             let ufrag_pwd = self.ufrag_pwd.lock().await;
745             ufrag_pwd.local_pwd.clone()
746         };
747 
748         let (out, result) = {
749             let mut out = Message::new();
750             let result = out.build(&[
751                 Box::new(m.clone()),
752                 Box::new(BINDING_SUCCESS),
753                 Box::new(XorMappedAddress { ip, port }),
754                 Box::new(MessageIntegrity::new_short_term_integrity(local_pwd)),
755                 Box::new(FINGERPRINT),
756             ]);
757             (out, result)
758         };
759 
760         if let Err(err) = result {
761             log::warn!(
762                 "[{}]: Failed to handle inbound ICE from: {} to: {} error: {}",
763                 self.get_name(),
764                 local,
765                 remote,
766                 err
767             );
768         } else {
769             self.send_stun(&out, local, remote).await;
770         }
771     }
772 
773     /// Removes pending binding requests that are over `maxBindingRequestTimeout` old Let HTO be the
774     /// transaction timeout, which SHOULD be 2*RTT if RTT is known or 500 ms otherwise.
775     ///
776     /// reference: (IETF ref-8445)[https://tools.ietf.org/html/rfc8445#appendix-B.1].
invalidate_pending_binding_requests(&self, filter_time: Instant)777     pub(crate) async fn invalidate_pending_binding_requests(&self, filter_time: Instant) {
778         let mut pending_binding_requests = self.pending_binding_requests.lock().await;
779         let initial_size = pending_binding_requests.len();
780 
781         let mut temp = vec![];
782         for binding_request in pending_binding_requests.drain(..) {
783             if filter_time
784                 .checked_duration_since(binding_request.timestamp)
785                 .map(|duration| duration < MAX_BINDING_REQUEST_TIMEOUT)
786                 .unwrap_or(true)
787             {
788                 temp.push(binding_request);
789             }
790         }
791 
792         *pending_binding_requests = temp;
793         let bind_requests_removed = initial_size - pending_binding_requests.len();
794         if bind_requests_removed > 0 {
795             log::trace!(
796                 "[{}]: Discarded {} binding requests because they expired",
797                 self.get_name(),
798                 bind_requests_removed
799             );
800         }
801     }
802 
803     /// Assert that the passed `TransactionID` is in our `pendingBindingRequests` and returns the
804     /// destination, If the bindingRequest was valid remove it from our pending cache.
handle_inbound_binding_success( &self, id: TransactionId, ) -> Option<BindingRequest>805     pub(crate) async fn handle_inbound_binding_success(
806         &self,
807         id: TransactionId,
808     ) -> Option<BindingRequest> {
809         self.invalidate_pending_binding_requests(Instant::now())
810             .await;
811 
812         let mut pending_binding_requests = self.pending_binding_requests.lock().await;
813         for i in 0..pending_binding_requests.len() {
814             if pending_binding_requests[i].transaction_id == id {
815                 let valid_binding_request = pending_binding_requests.remove(i);
816                 return Some(valid_binding_request);
817             }
818         }
819         None
820     }
821 
822     /// Processes STUN traffic from a remote candidate.
handle_inbound( &self, m: &mut Message, local: &Arc<dyn Candidate + Send + Sync>, remote: SocketAddr, )823     pub(crate) async fn handle_inbound(
824         &self,
825         m: &mut Message,
826         local: &Arc<dyn Candidate + Send + Sync>,
827         remote: SocketAddr,
828     ) {
829         if m.typ.method != METHOD_BINDING
830             || !(m.typ.class == CLASS_SUCCESS_RESPONSE
831                 || m.typ.class == CLASS_REQUEST
832                 || m.typ.class == CLASS_INDICATION)
833         {
834             log::trace!(
835                 "[{}]: unhandled STUN from {} to {} class({}) method({})",
836                 self.get_name(),
837                 remote,
838                 local,
839                 m.typ.class,
840                 m.typ.method
841             );
842             return;
843         }
844 
845         if self.is_controlling.load(Ordering::SeqCst) {
846             if m.contains(ATTR_ICE_CONTROLLING) {
847                 log::debug!(
848                     "[{}]: inbound isControlling && a.isControlling == true",
849                     self.get_name(),
850                 );
851                 return;
852             } else if m.contains(ATTR_USE_CANDIDATE) {
853                 log::debug!(
854                     "[{}]: useCandidate && a.isControlling == true",
855                     self.get_name(),
856                 );
857                 return;
858             }
859         } else if m.contains(ATTR_ICE_CONTROLLED) {
860             log::debug!(
861                 "[{}]: inbound isControlled && a.isControlling == false",
862                 self.get_name(),
863             );
864             return;
865         }
866 
867         let mut remote_candidate = self
868             .find_remote_candidate(local.network_type(), remote)
869             .await;
870         if m.typ.class == CLASS_SUCCESS_RESPONSE {
871             {
872                 let ufrag_pwd = self.ufrag_pwd.lock().await;
873                 if let Err(err) =
874                     assert_inbound_message_integrity(m, ufrag_pwd.remote_pwd.as_bytes())
875                 {
876                     log::warn!(
877                         "[{}]: discard message from ({}), {}",
878                         self.get_name(),
879                         remote,
880                         err
881                     );
882                     return;
883                 }
884             }
885 
886             if let Some(rc) = &remote_candidate {
887                 self.handle_success_response(m, local, rc, remote).await;
888             } else {
889                 log::warn!(
890                     "[{}]: discard success message from ({}), no such remote",
891                     self.get_name(),
892                     remote
893                 );
894                 return;
895             }
896         } else if m.typ.class == CLASS_REQUEST {
897             {
898                 let ufrag_pwd = self.ufrag_pwd.lock().await;
899                 let username =
900                     ufrag_pwd.local_ufrag.clone() + ":" + ufrag_pwd.remote_ufrag.as_str();
901                 if let Err(err) = assert_inbound_username(m, &username) {
902                     log::warn!(
903                         "[{}]: discard message from ({}), {}",
904                         self.get_name(),
905                         remote,
906                         err
907                     );
908                     return;
909                 } else if let Err(err) =
910                     assert_inbound_message_integrity(m, ufrag_pwd.local_pwd.as_bytes())
911                 {
912                     log::warn!(
913                         "[{}]: discard message from ({}), {}",
914                         self.get_name(),
915                         remote,
916                         err
917                     );
918                     return;
919                 }
920             }
921 
922             if remote_candidate.is_none() {
923                 let (ip, port, network_type) = (remote.ip(), remote.port(), NetworkType::Udp4);
924 
925                 let prflx_candidate_config = CandidatePeerReflexiveConfig {
926                     base_config: CandidateBaseConfig {
927                         network: network_type.to_string(),
928                         address: ip.to_string(),
929                         port,
930                         component: local.component(),
931                         ..CandidateBaseConfig::default()
932                     },
933                     rel_addr: "".to_owned(),
934                     rel_port: 0,
935                 };
936 
937                 match prflx_candidate_config.new_candidate_peer_reflexive() {
938                     Ok(prflx_candidate) => remote_candidate = Some(Arc::new(prflx_candidate)),
939                     Err(err) => {
940                         log::error!(
941                             "[{}]: Failed to create new remote prflx candidate ({})",
942                             self.get_name(),
943                             err
944                         );
945                         return;
946                     }
947                 };
948 
949                 log::debug!(
950                     "[{}]: adding a new peer-reflexive candidate: {} ",
951                     self.get_name(),
952                     remote
953                 );
954                 if let Some(rc) = &remote_candidate {
955                     self.add_remote_candidate(rc).await;
956                 }
957             }
958 
959             log::trace!(
960                 "[{}]: inbound STUN (Request) from {} to {}",
961                 self.get_name(),
962                 remote,
963                 local
964             );
965 
966             if let Some(rc) = &remote_candidate {
967                 self.handle_binding_request(m, local, rc).await;
968             }
969         }
970 
971         if let Some(rc) = remote_candidate {
972             rc.seen(false);
973         }
974     }
975 
976     /// Processes non STUN traffic from a remote candidate, and returns true if it is an actual
977     /// remote candidate.
validate_non_stun_traffic( &self, local: &Arc<dyn Candidate + Send + Sync>, remote: SocketAddr, ) -> bool978     pub(crate) async fn validate_non_stun_traffic(
979         &self,
980         local: &Arc<dyn Candidate + Send + Sync>,
981         remote: SocketAddr,
982     ) -> bool {
983         self.find_remote_candidate(local.network_type(), remote)
984             .await
985             .map_or(false, |remote_candidate| {
986                 remote_candidate.seen(false);
987                 true
988             })
989     }
990 
991     /// Sets the credentials of the remote agent.
set_remote_credentials( &self, remote_ufrag: String, remote_pwd: String, ) -> Result<()>992     pub(crate) async fn set_remote_credentials(
993         &self,
994         remote_ufrag: String,
995         remote_pwd: String,
996     ) -> Result<()> {
997         if remote_ufrag.is_empty() {
998             return Err(Error::ErrRemoteUfragEmpty);
999         } else if remote_pwd.is_empty() {
1000             return Err(Error::ErrRemotePwdEmpty);
1001         }
1002 
1003         let mut ufrag_pwd = self.ufrag_pwd.lock().await;
1004         ufrag_pwd.remote_ufrag = remote_ufrag;
1005         ufrag_pwd.remote_pwd = remote_pwd;
1006         Ok(())
1007     }
1008 
send_stun( &self, msg: &Message, local: &Arc<dyn Candidate + Send + Sync>, remote: &Arc<dyn Candidate + Send + Sync>, )1009     pub(crate) async fn send_stun(
1010         &self,
1011         msg: &Message,
1012         local: &Arc<dyn Candidate + Send + Sync>,
1013         remote: &Arc<dyn Candidate + Send + Sync>,
1014     ) {
1015         if let Err(err) = local.write_to(&msg.raw, &**remote).await {
1016             log::trace!(
1017                 "[{}]: failed to send STUN message: {}",
1018                 self.get_name(),
1019                 err
1020             );
1021         }
1022     }
1023 
1024     /// Runs the candidate using the provided connection.
start_candidate( self: &Arc<Self>, candidate: &Arc<dyn Candidate + Send + Sync>, initialized_ch: Option<broadcast::Receiver<()>>, )1025     async fn start_candidate(
1026         self: &Arc<Self>,
1027         candidate: &Arc<dyn Candidate + Send + Sync>,
1028         initialized_ch: Option<broadcast::Receiver<()>>,
1029     ) {
1030         let (closed_ch_tx, closed_ch_rx) = broadcast::channel(1);
1031         {
1032             let closed_ch = candidate.get_closed_ch();
1033             let mut closed = closed_ch.lock().await;
1034             *closed = Some(closed_ch_tx);
1035         }
1036 
1037         let cand = Arc::clone(candidate);
1038         if let Some(conn) = candidate.get_conn() {
1039             let conn = Arc::clone(conn);
1040             let addr = candidate.addr();
1041             let ai = Arc::clone(self);
1042             tokio::spawn(async move {
1043                 let _ = ai
1044                     .recv_loop(cand, closed_ch_rx, initialized_ch, conn, addr)
1045                     .await;
1046             });
1047         } else {
1048             log::error!("[{}]: Can't start due to conn is_none", self.get_name(),);
1049         }
1050     }
1051 
start_on_connection_state_change_routine( self: &Arc<Self>, mut chan_state_rx: mpsc::Receiver<ConnectionState>, mut chan_candidate_rx: mpsc::Receiver<Option<Arc<dyn Candidate + Send + Sync>>>, mut chan_candidate_pair_rx: mpsc::Receiver<()>, )1052     pub(super) fn start_on_connection_state_change_routine(
1053         self: &Arc<Self>,
1054         mut chan_state_rx: mpsc::Receiver<ConnectionState>,
1055         mut chan_candidate_rx: mpsc::Receiver<Option<Arc<dyn Candidate + Send + Sync>>>,
1056         mut chan_candidate_pair_rx: mpsc::Receiver<()>,
1057     ) {
1058         let ai = Arc::clone(self);
1059         tokio::spawn(async move {
1060             // CandidatePair and ConnectionState are usually changed at once.
1061             // Blocking one by the other one causes deadlock.
1062             while chan_candidate_pair_rx.recv().await.is_some() {
1063                 if let (Some(cb), Some(p)) = (
1064                     &*ai.on_selected_candidate_pair_change_hdlr.load(),
1065                     &*ai.agent_conn.selected_pair.load(),
1066                 ) {
1067                     let mut f = cb.lock().await;
1068                     f(&p.local, &p.remote).await;
1069                 }
1070             }
1071         });
1072 
1073         let ai = Arc::clone(self);
1074         tokio::spawn(async move {
1075             loop {
1076                 tokio::select! {
1077                     opt_state = chan_state_rx.recv() => {
1078                         if let Some(s) = opt_state {
1079                             if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() {
1080                                 let mut f = handler.lock().await;
1081                                 f(s).await;
1082                             }
1083                         } else {
1084                             while let Some(c) = chan_candidate_rx.recv().await {
1085                                 if let Some(handler) = &*ai.on_candidate_hdlr.load() {
1086                                     let mut f = handler.lock().await;
1087                                     f(c).await;
1088                                 }
1089                             }
1090                             break;
1091                         }
1092                     },
1093                     opt_cand = chan_candidate_rx.recv() => {
1094                         if let Some(c) = opt_cand {
1095                             if let Some(handler) = &*ai.on_candidate_hdlr.load() {
1096                                 let mut f = handler.lock().await;
1097                                 f(c).await;
1098                             }
1099                         } else {
1100                             while let Some(s) = chan_state_rx.recv().await {
1101                                 if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() {
1102                                     let mut f = handler.lock().await;
1103                                     f(s).await;
1104                                 }
1105                             }
1106                             break;
1107                         }
1108                     }
1109                 }
1110             }
1111         });
1112     }
1113 
recv_loop( self: &Arc<Self>, candidate: Arc<dyn Candidate + Send + Sync>, mut closed_ch_rx: broadcast::Receiver<()>, initialized_ch: Option<broadcast::Receiver<()>>, conn: Arc<dyn util::Conn + Send + Sync>, addr: SocketAddr, ) -> Result<()>1114     async fn recv_loop(
1115         self: &Arc<Self>,
1116         candidate: Arc<dyn Candidate + Send + Sync>,
1117         mut closed_ch_rx: broadcast::Receiver<()>,
1118         initialized_ch: Option<broadcast::Receiver<()>>,
1119         conn: Arc<dyn util::Conn + Send + Sync>,
1120         addr: SocketAddr,
1121     ) -> Result<()> {
1122         if let Some(mut initialized_ch) = initialized_ch {
1123             tokio::select! {
1124                 _ = initialized_ch.recv() => {}
1125                 _ = closed_ch_rx.recv() => return Err(Error::ErrClosed),
1126             }
1127         }
1128 
1129         let mut buffer = vec![0_u8; RECEIVE_MTU];
1130         let mut n;
1131         let mut src_addr;
1132         loop {
1133             tokio::select! {
1134                result = conn.recv_from(&mut buffer) => {
1135                    match result {
1136                        Ok((num, src)) => {
1137                             n = num;
1138                             src_addr = src;
1139                        }
1140                        Err(err) => return Err(Error::Other(err.to_string())),
1141                    }
1142                },
1143                 _  = closed_ch_rx.recv() => return Err(Error::ErrClosed),
1144             }
1145 
1146             self.handle_inbound_candidate_msg(&candidate, &buffer[..n], src_addr, addr)
1147                 .await;
1148         }
1149     }
1150 
handle_inbound_candidate_msg( self: &Arc<Self>, c: &Arc<dyn Candidate + Send + Sync>, buf: &[u8], src_addr: SocketAddr, addr: SocketAddr, )1151     async fn handle_inbound_candidate_msg(
1152         self: &Arc<Self>,
1153         c: &Arc<dyn Candidate + Send + Sync>,
1154         buf: &[u8],
1155         src_addr: SocketAddr,
1156         addr: SocketAddr,
1157     ) {
1158         if stun::message::is_message(buf) {
1159             let mut m = Message {
1160                 raw: vec![],
1161                 ..Message::default()
1162             };
1163             // Explicitly copy raw buffer so Message can own the memory.
1164             m.raw.extend_from_slice(buf);
1165 
1166             if let Err(err) = m.decode() {
1167                 log::warn!(
1168                     "[{}]: Failed to handle decode ICE from {} to {}: {}",
1169                     self.get_name(),
1170                     addr,
1171                     src_addr,
1172                     err
1173                 );
1174             } else {
1175                 self.handle_inbound(&mut m, c, src_addr).await;
1176             }
1177         } else if !self.validate_non_stun_traffic(c, src_addr).await {
1178             log::warn!(
1179                 "[{}]: Discarded message, not a valid remote candidate",
1180                 self.get_name(),
1181                 //c.addr().await //from {}
1182             );
1183         } else if let Err(err) = self.agent_conn.buffer.write(buf).await {
1184             // NOTE This will return packetio.ErrFull if the buffer ever manages to fill up.
1185             log::warn!("[{}]: failed to write packet: {}", self.get_name(), err);
1186         }
1187     }
1188 
get_name(&self) -> &str1189     pub(crate) fn get_name(&self) -> &str {
1190         if self.is_controlling.load(Ordering::SeqCst) {
1191             "controlling"
1192         } else {
1193             "controlled"
1194         }
1195     }
1196 }
1197