xref: /webrtc/turn/src/client/transaction.rs (revision ffe74184)
1 use crate::error::*;
2 
3 use stun::message::*;
4 
5 use std::collections::HashMap;
6 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7 use std::str::FromStr;
8 use std::sync::atomic::{AtomicU16, Ordering};
9 use std::sync::Arc;
10 use tokio::sync::{mpsc, Mutex};
11 use tokio::time::Duration;
12 use util::Conn;
13 
/// Ceiling, in milliseconds, for the retransmission interval; the interval is doubled after
/// every timer expiry and clamped to this value.
const MAX_RTX_INTERVAL_IN_MS: u16 = 1_600;
/// Retransmission count at which a transaction is failed: 7 requests in total (Rc).
const MAX_RTX_COUNT: u16 = 7;
16 
on_rtx_timeout( conn: &Arc<dyn Conn + Send + Sync>, tr_map: &Arc<Mutex<TransactionMap>>, tr_key: &str, n_rtx: u16, ) -> bool17 async fn on_rtx_timeout(
18     conn: &Arc<dyn Conn + Send + Sync>,
19     tr_map: &Arc<Mutex<TransactionMap>>,
20     tr_key: &str,
21     n_rtx: u16,
22 ) -> bool {
23     let mut tm = tr_map.lock().await;
24     let (tr_raw, tr_to) = match tm.find(tr_key) {
25         Some(tr) => (tr.raw.clone(), tr.to.clone()),
26         None => return true, // already gone
27     };
28 
29     if n_rtx == MAX_RTX_COUNT {
30         // all retransmisstions failed
31         if let Some(tr) = tm.delete(tr_key) {
32             if !tr
33                 .write_result(TransactionResult {
34                     err: Some(Error::Other(format!(
35                         "{:?} {}",
36                         Error::ErrAllRetransmissionsFailed,
37                         tr_key
38                     ))),
39                     ..Default::default()
40                 })
41                 .await
42             {
43                 log::debug!("no listener for transaction");
44             }
45         }
46         return true;
47     }
48 
49     log::trace!(
50         "retransmitting transaction {} to {} (n_rtx={})",
51         tr_key,
52         tr_to,
53         n_rtx
54     );
55 
56     let dst = match SocketAddr::from_str(&tr_to) {
57         Ok(dst) => dst,
58         Err(_) => return false,
59     };
60 
61     if conn.send_to(&tr_raw, dst).await.is_err() {
62         if let Some(tr) = tm.delete(tr_key) {
63             if !tr
64                 .write_result(TransactionResult {
65                     err: Some(Error::Other(format!(
66                         "{:?} {}",
67                         Error::ErrAllRetransmissionsFailed,
68                         tr_key
69                     ))),
70                     ..Default::default()
71                 })
72                 .await
73             {
74                 log::debug!("no listener for transaction");
75             }
76         }
77         return true;
78     }
79 
80     false
81 }
82 
83 // TransactionResult is a bag of result values of a transaction
84 #[derive(Debug)] //Clone
85 pub struct TransactionResult {
86     pub msg: Message,
87     pub from: SocketAddr,
88     pub retries: u16,
89     pub err: Option<Error>,
90 }
91 
92 impl Default for TransactionResult {
default() -> Self93     fn default() -> Self {
94         TransactionResult {
95             msg: Message::default(),
96             from: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
97             retries: 0,
98             err: None,
99         }
100     }
101 }
102 
/// TransactionConfig holds the parameters consumed by `Transaction::new`.
#[derive(Default)]
pub struct TransactionConfig {
    /// Key identifying the transaction; copied into `Transaction::key`.
    pub key: String,
    /// Raw bytes of the request; these are what gets sent again on retransmission.
    pub raw: Vec<u8>,
    /// Destination in textual form; it is parsed as a `SocketAddr` when retransmitting.
    pub to: String,
    /// Initial retransmission interval, in milliseconds.
    pub interval: u16,
    /// When `true` the result of the transaction is thrown away: no result channel is
    /// created, so it cannot be read through the result channel.
    pub ignore_result: bool,
}
112 
113 // Transaction represents a transaction
114 #[derive(Debug)]
115 pub struct Transaction {
116     pub key: String,
117     pub raw: Vec<u8>,
118     pub to: String,
119     pub n_rtx: Arc<AtomicU16>,
120     pub interval: Arc<AtomicU16>,
121     timer_ch_tx: Option<mpsc::Sender<()>>,
122     result_ch_tx: Option<mpsc::Sender<TransactionResult>>,
123     result_ch_rx: Option<mpsc::Receiver<TransactionResult>>,
124 }
125 
126 impl Default for Transaction {
default() -> Self127     fn default() -> Self {
128         Transaction {
129             key: String::new(),
130             raw: vec![],
131             to: String::new(),
132             n_rtx: Arc::new(AtomicU16::new(0)),
133             interval: Arc::new(AtomicU16::new(0)),
134             //timer: None,
135             timer_ch_tx: None,
136             result_ch_tx: None,
137             result_ch_rx: None,
138         }
139     }
140 }
141 
142 impl Transaction {
143     // NewTransaction creates a new instance of Transaction
new(config: TransactionConfig) -> Self144     pub fn new(config: TransactionConfig) -> Self {
145         let (result_ch_tx, result_ch_rx) = if !config.ignore_result {
146             let (tx, rx) = mpsc::channel(1);
147             (Some(tx), Some(rx))
148         } else {
149             (None, None)
150         };
151 
152         Transaction {
153             key: config.key,
154             raw: config.raw,
155             to: config.to,
156             interval: Arc::new(AtomicU16::new(config.interval)),
157             result_ch_tx,
158             result_ch_rx,
159             ..Default::default()
160         }
161     }
162 
163     // start_rtx_timer starts the transaction timer
start_rtx_timer( &mut self, conn: Arc<dyn Conn + Send + Sync>, tr_map: Arc<Mutex<TransactionMap>>, )164     pub async fn start_rtx_timer(
165         &mut self,
166         conn: Arc<dyn Conn + Send + Sync>,
167         tr_map: Arc<Mutex<TransactionMap>>,
168     ) {
169         let (timer_ch_tx, mut timer_ch_rx) = mpsc::channel(1);
170         self.timer_ch_tx = Some(timer_ch_tx);
171         let (n_rtx, interval, key) = (self.n_rtx.clone(), self.interval.clone(), self.key.clone());
172 
173         tokio::spawn(async move {
174             let mut done = false;
175             while !done {
176                 let timer = tokio::time::sleep(Duration::from_millis(
177                     interval.load(Ordering::SeqCst) as u64,
178                 ));
179                 tokio::pin!(timer);
180 
181                 tokio::select! {
182                     _ = timer.as_mut() => {
183                         let rtx = n_rtx.fetch_add(1, Ordering::SeqCst);
184 
185                         let mut val = interval.load(Ordering::SeqCst);
186                         val *= 2;
187                         if val > MAX_RTX_INTERVAL_IN_MS {
188                             val = MAX_RTX_INTERVAL_IN_MS;
189                         }
190                         interval.store(val, Ordering::SeqCst);
191 
192                         done = on_rtx_timeout(&conn, &tr_map, &key, rtx + 1).await;
193                     }
194                     _ = timer_ch_rx.recv() => done = true,
195                 }
196             }
197         });
198     }
199 
200     // stop_rtx_timer stop the transaction timer
stop_rtx_timer(&mut self)201     pub fn stop_rtx_timer(&mut self) {
202         if self.timer_ch_tx.is_some() {
203             self.timer_ch_tx.take();
204         }
205     }
206 
207     // write_result writes the result to the result channel
write_result(&self, res: TransactionResult) -> bool208     pub async fn write_result(&self, res: TransactionResult) -> bool {
209         if let Some(result_ch) = &self.result_ch_tx {
210             result_ch.send(res).await.is_ok()
211         } else {
212             false
213         }
214     }
215 
get_result_channel(&mut self) -> Option<mpsc::Receiver<TransactionResult>>216     pub fn get_result_channel(&mut self) -> Option<mpsc::Receiver<TransactionResult>> {
217         self.result_ch_rx.take()
218     }
219 
220     // Close closes the transaction
close(&mut self)221     pub fn close(&mut self) {
222         if self.result_ch_tx.is_some() {
223             self.result_ch_tx.take();
224         }
225     }
226 
227     // retries returns the number of retransmission it has made
retries(&self) -> u16228     pub fn retries(&self) -> u16 {
229         self.n_rtx.load(Ordering::SeqCst)
230     }
231 }
232 
233 // TransactionMap is a thread-safe transaction map
234 #[derive(Default, Debug)]
235 pub struct TransactionMap {
236     tr_map: HashMap<String, Transaction>,
237 }
238 
239 impl TransactionMap {
240     // NewTransactionMap create a new instance of the transaction map
new() -> TransactionMap241     pub fn new() -> TransactionMap {
242         TransactionMap {
243             tr_map: HashMap::new(),
244         }
245     }
246 
247     // Insert inserts a trasaction to the map
insert(&mut self, key: String, tr: Transaction) -> bool248     pub fn insert(&mut self, key: String, tr: Transaction) -> bool {
249         self.tr_map.insert(key, tr);
250         true
251     }
252 
253     // Find looks up a transaction by its key
find(&self, key: &str) -> Option<&Transaction>254     pub fn find(&self, key: &str) -> Option<&Transaction> {
255         self.tr_map.get(key)
256     }
257 
get(&mut self, key: &str) -> Option<&mut Transaction>258     pub fn get(&mut self, key: &str) -> Option<&mut Transaction> {
259         self.tr_map.get_mut(key)
260     }
261 
262     // Delete deletes a transaction by its key
delete(&mut self, key: &str) -> Option<Transaction>263     pub fn delete(&mut self, key: &str) -> Option<Transaction> {
264         self.tr_map.remove(key)
265     }
266 
267     // close_and_delete_all closes and deletes all transactions
close_and_delete_all(&mut self)268     pub fn close_and_delete_all(&mut self) {
269         for tr in self.tr_map.values_mut() {
270             tr.close();
271         }
272         self.tr_map.clear();
273     }
274 
275     // Size returns the length of the transaction map
size(&self) -> usize276     pub fn size(&self) -> usize {
277         self.tr_map.len()
278     }
279 }
280