xref: /webrtc/stun/src/client.rs (revision 630c46fe)
// The client's unit tests live in the `client_test` submodule and are only
// compiled for test builds.
#[cfg(test)]
mod client_test;
3 
4 use crate::agent::*;
5 use crate::error::*;
6 use crate::message::*;
7 
8 use util::Conn;
9 
10 use std::collections::HashMap;
11 use std::io::BufReader;
12 use std::marker::{Send, Sync};
13 use std::ops::Add;
14 use std::sync::Arc;
15 use tokio::sync::mpsc;
16 use tokio::time::{self, Duration, Instant};
17 
/// Default tick interval of the timeout collector (RTO timer resolution).
const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_millis(5);
/// Default base retransmission timeout; attempt `n` waits `(n + 1) * rto`.
const DEFAULT_RTO: Duration = Duration::from_millis(300);
/// Default maximum number of retransmissions per transaction.
const DEFAULT_MAX_ATTEMPTS: u32 = 7;
/// Default capacity of the client-to-agent command channel.
const DEFAULT_MAX_BUFFER_SIZE: usize = 8;
22 
23 /// Collector calls function f with constant rate.
24 ///
25 /// The simple Collector is ticker which calls function on each tick.
26 pub trait Collector {
start( &mut self, rate: Duration, client_agent_tx: Arc<mpsc::Sender<ClientAgent>>, ) -> Result<()>27     fn start(
28         &mut self,
29         rate: Duration,
30         client_agent_tx: Arc<mpsc::Sender<ClientAgent>>,
31     ) -> Result<()>;
close(&mut self) -> Result<()>32     fn close(&mut self) -> Result<()>;
33 }
34 
35 #[derive(Default)]
36 struct TickerCollector {
37     close_tx: Option<mpsc::Sender<()>>,
38 }
39 
40 impl Collector for TickerCollector {
start( &mut self, rate: Duration, client_agent_tx: Arc<mpsc::Sender<ClientAgent>>, ) -> Result<()>41     fn start(
42         &mut self,
43         rate: Duration,
44         client_agent_tx: Arc<mpsc::Sender<ClientAgent>>,
45     ) -> Result<()> {
46         let (close_tx, mut close_rx) = mpsc::channel(1);
47         self.close_tx = Some(close_tx);
48 
49         tokio::spawn(async move {
50             let mut interval = time::interval(rate);
51 
52             loop {
53                 tokio::select! {
54                     _ = close_rx.recv() => break,
55                     _ = interval.tick() => {
56                         if client_agent_tx.send(ClientAgent::Collect(Instant::now())).await.is_err() {
57                             break;
58                         }
59                     }
60                 }
61             }
62         });
63 
64         Ok(())
65     }
66 
close(&mut self) -> Result<()>67     fn close(&mut self) -> Result<()> {
68         if self.close_tx.is_none() {
69             return Err(Error::ErrCollectorClosed);
70         }
71         self.close_tx.take();
72         Ok(())
73     }
74 }
75 
76 /// ClientTransaction represents transaction in progress.
77 /// If transaction is succeed or failed, f will be called
78 /// provided by event.
79 /// Concurrent access is invalid.
80 #[derive(Debug, Clone)]
81 pub struct ClientTransaction {
82     id: TransactionId,
83     attempt: u32,
84     calls: u32,
85     handler: Handler,
86     start: Instant,
87     rto: Duration,
88     raw: Vec<u8>,
89 }
90 
91 impl ClientTransaction {
handle(&mut self, e: Event) -> Result<()>92     pub(crate) fn handle(&mut self, e: Event) -> Result<()> {
93         self.calls += 1;
94         if self.calls == 1 {
95             if let Some(handler) = &self.handler {
96                 handler.send(e)?;
97             }
98         }
99         Ok(())
100     }
101 
next_timeout(&self, now: Instant) -> Instant102     pub(crate) fn next_timeout(&self, now: Instant) -> Instant {
103         now.add((self.attempt + 1) * self.rto)
104     }
105 }
106 
107 struct ClientSettings {
108     buffer_size: usize,
109     rto: Duration,
110     rto_rate: Duration,
111     max_attempts: u32,
112     closed: bool,
113     //handler: Handler,
114     collector: Option<Box<dyn Collector + Send>>,
115     c: Option<Arc<dyn Conn + Send + Sync>>,
116 }
117 
118 impl Default for ClientSettings {
default() -> Self119     fn default() -> Self {
120         ClientSettings {
121             buffer_size: DEFAULT_MAX_BUFFER_SIZE,
122             rto: DEFAULT_RTO,
123             rto_rate: DEFAULT_TIMEOUT_RATE,
124             max_attempts: DEFAULT_MAX_ATTEMPTS,
125             closed: false,
126             //handler: None,
127             collector: None,
128             c: None,
129         }
130     }
131 }
132 
133 #[derive(Default)]
134 pub struct ClientBuilder {
135     settings: ClientSettings,
136 }
137 
138 impl ClientBuilder {
139     // WithHandler sets client handler which is called if Agent emits the Event
140     // with TransactionID that is not currently registered by Client.
141     // Useful for handling Data indications from TURN server.
142     //pub fn with_handler(mut self, handler: Handler) -> Self {
143     //    self.settings.handler = handler;
144     //    self
145     //}
146 
147     /// with_rto sets client RTO as defined in STUN RFC.
with_rto(mut self, rto: Duration) -> Self148     pub fn with_rto(mut self, rto: Duration) -> Self {
149         self.settings.rto = rto;
150         self
151     }
152 
153     /// with_timeout_rate sets RTO timer minimum resolution.
with_timeout_rate(mut self, d: Duration) -> Self154     pub fn with_timeout_rate(mut self, d: Duration) -> Self {
155         self.settings.rto_rate = d;
156         self
157     }
158 
159     /// with_buffer_size sets buffer size.
with_buffer_size(mut self, buffer_size: usize) -> Self160     pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
161         self.settings.buffer_size = buffer_size;
162         self
163     }
164 
165     /// with_collector rests client timeout collector, the implementation
166     /// of ticker which calls function on each tick.
with_collector(mut self, coll: Box<dyn Collector + Send>) -> Self167     pub fn with_collector(mut self, coll: Box<dyn Collector + Send>) -> Self {
168         self.settings.collector = Some(coll);
169         self
170     }
171 
172     /// with_conn sets transport connection
with_conn(mut self, conn: Arc<dyn Conn + Send + Sync>) -> Self173     pub fn with_conn(mut self, conn: Arc<dyn Conn + Send + Sync>) -> Self {
174         self.settings.c = Some(conn);
175         self
176     }
177 
178     /// with_no_retransmit disables retransmissions and sets RTO to
179     /// DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO which will be effectively time out
180     /// if not set.
181     /// Useful for TCP connections where transport handles RTO.
with_no_retransmit(mut self) -> Self182     pub fn with_no_retransmit(mut self) -> Self {
183         self.settings.max_attempts = 0;
184         if self.settings.rto == Duration::from_secs(0) {
185             self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO;
186         }
187         self
188     }
189 
new() -> Self190     pub fn new() -> Self {
191         ClientBuilder {
192             settings: ClientSettings::default(),
193         }
194     }
195 
build(self) -> Result<Client>196     pub fn build(self) -> Result<Client> {
197         if self.settings.c.is_none() {
198             return Err(Error::ErrNoConnection);
199         }
200 
201         let client = Client {
202             settings: self.settings,
203             ..Default::default()
204         }
205         .run()?;
206 
207         Ok(client)
208     }
209 }
210 
211 /// Client simulates "connection" to STUN server.
212 #[derive(Default)]
213 pub struct Client {
214     settings: ClientSettings,
215     close_tx: Option<mpsc::Sender<()>>,
216     client_agent_tx: Option<Arc<mpsc::Sender<ClientAgent>>>,
217     handler_tx: Option<Arc<mpsc::UnboundedSender<Event>>>,
218 }
219 
220 impl Client {
read_until_closed( mut close_rx: mpsc::Receiver<()>, c: Arc<dyn Conn + Send + Sync>, client_agent_tx: Arc<mpsc::Sender<ClientAgent>>, )221     async fn read_until_closed(
222         mut close_rx: mpsc::Receiver<()>,
223         c: Arc<dyn Conn + Send + Sync>,
224         client_agent_tx: Arc<mpsc::Sender<ClientAgent>>,
225     ) {
226         let mut msg = Message::new();
227         let mut buf = vec![0; 1024];
228 
229         loop {
230             tokio::select! {
231                 _ = close_rx.recv() => return,
232                 res = c.recv(&mut buf) => {
233                     if let Ok(n) = res {
234                         let mut reader = BufReader::new(&buf[..n]);
235                         let result = msg.read_from(&mut reader);
236                         if result.is_err() {
237                             continue;
238                         }
239 
240                         if client_agent_tx.send(ClientAgent::Process(msg.clone())).await.is_err(){
241                             return;
242                         }
243                     }
244                 }
245             }
246         }
247     }
248 
insert(&mut self, ct: ClientTransaction) -> Result<()>249     fn insert(&mut self, ct: ClientTransaction) -> Result<()> {
250         if self.settings.closed {
251             return Err(Error::ErrClientClosed);
252         }
253 
254         if let Some(handler_tx) = &mut self.handler_tx {
255             handler_tx.send(Event {
256                 event_type: EventType::Insert(ct),
257                 ..Default::default()
258             })?;
259         }
260 
261         Ok(())
262     }
263 
remove(&mut self, id: TransactionId) -> Result<()>264     fn remove(&mut self, id: TransactionId) -> Result<()> {
265         if self.settings.closed {
266             return Err(Error::ErrClientClosed);
267         }
268 
269         if let Some(handler_tx) = &mut self.handler_tx {
270             handler_tx.send(Event {
271                 event_type: EventType::Remove(id),
272                 ..Default::default()
273             })?;
274         }
275 
276         Ok(())
277     }
278 
start( conn: Option<Arc<dyn Conn + Send + Sync>>, mut handler_rx: mpsc::UnboundedReceiver<Event>, client_agent_tx: Arc<mpsc::Sender<ClientAgent>>, mut t: HashMap<TransactionId, ClientTransaction>, max_attempts: u32, )279     fn start(
280         conn: Option<Arc<dyn Conn + Send + Sync>>,
281         mut handler_rx: mpsc::UnboundedReceiver<Event>,
282         client_agent_tx: Arc<mpsc::Sender<ClientAgent>>,
283         mut t: HashMap<TransactionId, ClientTransaction>,
284         max_attempts: u32,
285     ) {
286         tokio::spawn(async move {
287             while let Some(event) = handler_rx.recv().await {
288                 match event.event_type {
289                     EventType::Close => {
290                         break;
291                     }
292                     EventType::Insert(ct) => {
293                         if t.contains_key(&ct.id) {
294                             continue;
295                         }
296                         t.insert(ct.id, ct);
297                     }
298                     EventType::Remove(id) => {
299                         t.remove(&id);
300                     }
301                     EventType::Callback(id) => {
302                         let mut ct = if t.contains_key(&id) {
303                             t.remove(&id).unwrap()
304                         } else {
305                             /*if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
306                                 c.handler(e)
307                             }*/
308                             continue;
309                         };
310 
311                         if ct.attempt >= max_attempts || event.event_body.is_ok() {
312                             if let Some(handler) = ct.handler {
313                                 let _ = handler.send(event);
314                             }
315                             continue;
316                         }
317 
318                         // Doing re-transmission.
319                         ct.attempt += 1;
320 
321                         let raw = ct.raw.clone();
322                         let timeout = ct.next_timeout(Instant::now());
323                         let id = ct.id;
324 
325                         // Starting client transaction.
326                         t.insert(ct.id, ct);
327 
328                         // Starting agent transaction.
329                         if client_agent_tx
330                             .send(ClientAgent::Start(id, timeout))
331                             .await
332                             .is_err()
333                         {
334                             let ct = t.remove(&id).unwrap();
335                             if let Some(handler) = ct.handler {
336                                 let _ = handler.send(event);
337                             }
338                             continue;
339                         }
340 
341                         // Writing message to connection again.
342                         if let Some(c) = &conn {
343                             if c.send(&raw).await.is_err() {
344                                 let _ = client_agent_tx.send(ClientAgent::Stop(id)).await;
345 
346                                 let ct = t.remove(&id).unwrap();
347                                 if let Some(handler) = ct.handler {
348                                     let _ = handler.send(event);
349                                 }
350                                 continue;
351                             }
352                         }
353                     }
354                 };
355             }
356         });
357     }
358 
359     /// close stops internal connection and agent, returning CloseErr on error.
close(&mut self) -> Result<()>360     pub async fn close(&mut self) -> Result<()> {
361         if self.settings.closed {
362             return Err(Error::ErrClientClosed);
363         }
364 
365         self.settings.closed = true;
366 
367         if let Some(collector) = &mut self.settings.collector {
368             let _ = collector.close();
369         }
370         self.settings.collector.take();
371 
372         self.close_tx.take(); //drop close channel
373         if let Some(client_agent_tx) = &mut self.client_agent_tx {
374             let _ = client_agent_tx.send(ClientAgent::Close).await;
375         }
376         self.client_agent_tx.take();
377 
378         if let Some(c) = self.settings.c.take() {
379             c.close().await?;
380         }
381 
382         Ok(())
383     }
384 
run(mut self) -> Result<Self>385     fn run(mut self) -> Result<Self> {
386         let (close_tx, close_rx) = mpsc::channel(1);
387         let (client_agent_tx, client_agent_rx) = mpsc::channel(self.settings.buffer_size);
388         let (handler_tx, handler_rx) = mpsc::unbounded_channel();
389         let t: HashMap<TransactionId, ClientTransaction> = HashMap::new();
390 
391         let client_agent_tx = Arc::new(client_agent_tx);
392         let handler_tx = Arc::new(handler_tx);
393         self.client_agent_tx = Some(Arc::clone(&client_agent_tx));
394         self.handler_tx = Some(Arc::clone(&handler_tx));
395         self.close_tx = Some(close_tx);
396 
397         let conn = if let Some(conn) = &self.settings.c {
398             Arc::clone(conn)
399         } else {
400             return Err(Error::ErrNoConnection);
401         };
402 
403         Client::start(
404             self.settings.c.clone(),
405             handler_rx,
406             Arc::clone(&client_agent_tx),
407             t,
408             self.settings.max_attempts,
409         );
410 
411         let agent = Agent::new(Some(handler_tx));
412         tokio::spawn(async move { Agent::run(agent, client_agent_rx).await });
413 
414         if self.settings.collector.is_none() {
415             self.settings.collector = Some(Box::<TickerCollector>::default());
416         }
417         if let Some(collector) = &mut self.settings.collector {
418             collector.start(self.settings.rto_rate, Arc::clone(&client_agent_tx))?;
419         }
420 
421         let conn_rx = Arc::clone(&conn);
422         tokio::spawn(
423             async move { Client::read_until_closed(close_rx, conn_rx, client_agent_tx).await },
424         );
425 
426         Ok(self)
427     }
428 
send(&mut self, m: &Message, handler: Handler) -> Result<()>429     pub async fn send(&mut self, m: &Message, handler: Handler) -> Result<()> {
430         if self.settings.closed {
431             return Err(Error::ErrClientClosed);
432         }
433 
434         let has_handler = handler.is_some();
435 
436         if handler.is_some() {
437             let t = ClientTransaction {
438                 id: m.transaction_id,
439                 attempt: 0,
440                 calls: 0,
441                 handler,
442                 start: Instant::now(),
443                 rto: self.settings.rto,
444                 raw: m.raw.clone(),
445             };
446             let d = t.next_timeout(t.start);
447             self.insert(t)?;
448 
449             if let Some(client_agent_tx) = &mut self.client_agent_tx {
450                 client_agent_tx
451                     .send(ClientAgent::Start(m.transaction_id, d))
452                     .await?;
453             }
454         }
455 
456         if let Some(c) = &self.settings.c {
457             let result = c.send(&m.raw).await;
458             if result.is_err() && has_handler {
459                 self.remove(m.transaction_id)?;
460 
461                 if let Some(client_agent_tx) = &mut self.client_agent_tx {
462                     client_agent_tx
463                         .send(ClientAgent::Stop(m.transaction_id))
464                         .await?;
465                 }
466             } else if let Err(err) = result {
467                 return Err(Error::Other(err.to_string()));
468             }
469         }
470 
471         Ok(())
472     }
473 }
474