1 use crate::error::*;
2
3 use stun::message::*;
4
5 use std::collections::HashMap;
6 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7 use std::str::FromStr;
8 use std::sync::atomic::{AtomicU16, Ordering};
9 use std::sync::Arc;
10 use tokio::sync::{mpsc, Mutex};
11 use tokio::time::Duration;
12 use util::Conn;
13
/// Ceiling, in milliseconds, for the retransmission interval after each doubling.
const MAX_RTX_INTERVAL_IN_MS: u16 = 1600;
/// Retransmission limit (`Rc`): once the retransmission counter reaches this value the
/// transaction is failed instead of resent (7 requests in total).
const MAX_RTX_COUNT: u16 = 7;
16
on_rtx_timeout( conn: &Arc<dyn Conn + Send + Sync>, tr_map: &Arc<Mutex<TransactionMap>>, tr_key: &str, n_rtx: u16, ) -> bool17 async fn on_rtx_timeout(
18 conn: &Arc<dyn Conn + Send + Sync>,
19 tr_map: &Arc<Mutex<TransactionMap>>,
20 tr_key: &str,
21 n_rtx: u16,
22 ) -> bool {
23 let mut tm = tr_map.lock().await;
24 let (tr_raw, tr_to) = match tm.find(tr_key) {
25 Some(tr) => (tr.raw.clone(), tr.to.clone()),
26 None => return true, // already gone
27 };
28
29 if n_rtx == MAX_RTX_COUNT {
30 // all retransmisstions failed
31 if let Some(tr) = tm.delete(tr_key) {
32 if !tr
33 .write_result(TransactionResult {
34 err: Some(Error::Other(format!(
35 "{:?} {}",
36 Error::ErrAllRetransmissionsFailed,
37 tr_key
38 ))),
39 ..Default::default()
40 })
41 .await
42 {
43 log::debug!("no listener for transaction");
44 }
45 }
46 return true;
47 }
48
49 log::trace!(
50 "retransmitting transaction {} to {} (n_rtx={})",
51 tr_key,
52 tr_to,
53 n_rtx
54 );
55
56 let dst = match SocketAddr::from_str(&tr_to) {
57 Ok(dst) => dst,
58 Err(_) => return false,
59 };
60
61 if conn.send_to(&tr_raw, dst).await.is_err() {
62 if let Some(tr) = tm.delete(tr_key) {
63 if !tr
64 .write_result(TransactionResult {
65 err: Some(Error::Other(format!(
66 "{:?} {}",
67 Error::ErrAllRetransmissionsFailed,
68 tr_key
69 ))),
70 ..Default::default()
71 })
72 .await
73 {
74 log::debug!("no listener for transaction");
75 }
76 }
77 return true;
78 }
79
80 false
81 }
82
83 // TransactionResult is a bag of result values of a transaction
84 #[derive(Debug)] //Clone
85 pub struct TransactionResult {
86 pub msg: Message,
87 pub from: SocketAddr,
88 pub retries: u16,
89 pub err: Option<Error>,
90 }
91
92 impl Default for TransactionResult {
default() -> Self93 fn default() -> Self {
94 TransactionResult {
95 msg: Message::default(),
96 from: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
97 retries: 0,
98 err: None,
99 }
100 }
101 }
102
/// Parameters consumed by `Transaction::new`.
#[derive(Default)]
pub struct TransactionConfig {
    /// Key under which the transaction is tracked.
    pub key: String,
    /// Encoded request bytes, resent on every retransmission.
    pub raw: Vec<u8>,
    /// Destination address in a form parseable as a `SocketAddr`.
    pub to: String,
    /// Initial retransmission interval in milliseconds.
    pub interval: u16,
    /// When `true`, no result channel is created, so the transaction's outcome
    /// cannot be read back.
    pub ignore_result: bool,
}
112
113 // Transaction represents a transaction
114 #[derive(Debug)]
115 pub struct Transaction {
116 pub key: String,
117 pub raw: Vec<u8>,
118 pub to: String,
119 pub n_rtx: Arc<AtomicU16>,
120 pub interval: Arc<AtomicU16>,
121 timer_ch_tx: Option<mpsc::Sender<()>>,
122 result_ch_tx: Option<mpsc::Sender<TransactionResult>>,
123 result_ch_rx: Option<mpsc::Receiver<TransactionResult>>,
124 }
125
126 impl Default for Transaction {
default() -> Self127 fn default() -> Self {
128 Transaction {
129 key: String::new(),
130 raw: vec![],
131 to: String::new(),
132 n_rtx: Arc::new(AtomicU16::new(0)),
133 interval: Arc::new(AtomicU16::new(0)),
134 //timer: None,
135 timer_ch_tx: None,
136 result_ch_tx: None,
137 result_ch_rx: None,
138 }
139 }
140 }
141
142 impl Transaction {
143 // NewTransaction creates a new instance of Transaction
new(config: TransactionConfig) -> Self144 pub fn new(config: TransactionConfig) -> Self {
145 let (result_ch_tx, result_ch_rx) = if !config.ignore_result {
146 let (tx, rx) = mpsc::channel(1);
147 (Some(tx), Some(rx))
148 } else {
149 (None, None)
150 };
151
152 Transaction {
153 key: config.key,
154 raw: config.raw,
155 to: config.to,
156 interval: Arc::new(AtomicU16::new(config.interval)),
157 result_ch_tx,
158 result_ch_rx,
159 ..Default::default()
160 }
161 }
162
163 // start_rtx_timer starts the transaction timer
start_rtx_timer( &mut self, conn: Arc<dyn Conn + Send + Sync>, tr_map: Arc<Mutex<TransactionMap>>, )164 pub async fn start_rtx_timer(
165 &mut self,
166 conn: Arc<dyn Conn + Send + Sync>,
167 tr_map: Arc<Mutex<TransactionMap>>,
168 ) {
169 let (timer_ch_tx, mut timer_ch_rx) = mpsc::channel(1);
170 self.timer_ch_tx = Some(timer_ch_tx);
171 let (n_rtx, interval, key) = (self.n_rtx.clone(), self.interval.clone(), self.key.clone());
172
173 tokio::spawn(async move {
174 let mut done = false;
175 while !done {
176 let timer = tokio::time::sleep(Duration::from_millis(
177 interval.load(Ordering::SeqCst) as u64,
178 ));
179 tokio::pin!(timer);
180
181 tokio::select! {
182 _ = timer.as_mut() => {
183 let rtx = n_rtx.fetch_add(1, Ordering::SeqCst);
184
185 let mut val = interval.load(Ordering::SeqCst);
186 val *= 2;
187 if val > MAX_RTX_INTERVAL_IN_MS {
188 val = MAX_RTX_INTERVAL_IN_MS;
189 }
190 interval.store(val, Ordering::SeqCst);
191
192 done = on_rtx_timeout(&conn, &tr_map, &key, rtx + 1).await;
193 }
194 _ = timer_ch_rx.recv() => done = true,
195 }
196 }
197 });
198 }
199
200 // stop_rtx_timer stop the transaction timer
stop_rtx_timer(&mut self)201 pub fn stop_rtx_timer(&mut self) {
202 if self.timer_ch_tx.is_some() {
203 self.timer_ch_tx.take();
204 }
205 }
206
207 // write_result writes the result to the result channel
write_result(&self, res: TransactionResult) -> bool208 pub async fn write_result(&self, res: TransactionResult) -> bool {
209 if let Some(result_ch) = &self.result_ch_tx {
210 result_ch.send(res).await.is_ok()
211 } else {
212 false
213 }
214 }
215
get_result_channel(&mut self) -> Option<mpsc::Receiver<TransactionResult>>216 pub fn get_result_channel(&mut self) -> Option<mpsc::Receiver<TransactionResult>> {
217 self.result_ch_rx.take()
218 }
219
220 // Close closes the transaction
close(&mut self)221 pub fn close(&mut self) {
222 if self.result_ch_tx.is_some() {
223 self.result_ch_tx.take();
224 }
225 }
226
227 // retries returns the number of retransmission it has made
retries(&self) -> u16228 pub fn retries(&self) -> u16 {
229 self.n_rtx.load(Ordering::SeqCst)
230 }
231 }
232
233 // TransactionMap is a thread-safe transaction map
234 #[derive(Default, Debug)]
235 pub struct TransactionMap {
236 tr_map: HashMap<String, Transaction>,
237 }
238
239 impl TransactionMap {
240 // NewTransactionMap create a new instance of the transaction map
new() -> TransactionMap241 pub fn new() -> TransactionMap {
242 TransactionMap {
243 tr_map: HashMap::new(),
244 }
245 }
246
247 // Insert inserts a trasaction to the map
insert(&mut self, key: String, tr: Transaction) -> bool248 pub fn insert(&mut self, key: String, tr: Transaction) -> bool {
249 self.tr_map.insert(key, tr);
250 true
251 }
252
253 // Find looks up a transaction by its key
find(&self, key: &str) -> Option<&Transaction>254 pub fn find(&self, key: &str) -> Option<&Transaction> {
255 self.tr_map.get(key)
256 }
257
get(&mut self, key: &str) -> Option<&mut Transaction>258 pub fn get(&mut self, key: &str) -> Option<&mut Transaction> {
259 self.tr_map.get_mut(key)
260 }
261
262 // Delete deletes a transaction by its key
delete(&mut self, key: &str) -> Option<Transaction>263 pub fn delete(&mut self, key: &str) -> Option<Transaction> {
264 self.tr_map.remove(key)
265 }
266
267 // close_and_delete_all closes and deletes all transactions
close_and_delete_all(&mut self)268 pub fn close_and_delete_all(&mut self) {
269 for tr in self.tr_map.values_mut() {
270 tr.close();
271 }
272 self.tr_map.clear();
273 }
274
275 // Size returns the length of the transaction map
size(&self) -> usize276 pub fn size(&self) -> usize {
277 self.tr_map.len()
278 }
279 }
280