1*ffe74184SMartin Algesten #[cfg(test)] 2*ffe74184SMartin Algesten mod buffer_test; 3*ffe74184SMartin Algesten 4*ffe74184SMartin Algesten use crate::error::{Error, Result}; 5*ffe74184SMartin Algesten 6*ffe74184SMartin Algesten use std::sync::Arc; 7*ffe74184SMartin Algesten use tokio::sync::{Mutex, Notify}; 8*ffe74184SMartin Algesten use tokio::time::{timeout, Duration}; 9*ffe74184SMartin Algesten 10*ffe74184SMartin Algesten const MIN_SIZE: usize = 2048; 11*ffe74184SMartin Algesten const CUTOFF_SIZE: usize = 128 * 1024; 12*ffe74184SMartin Algesten const MAX_SIZE: usize = 4 * 1024 * 1024; 13*ffe74184SMartin Algesten 14*ffe74184SMartin Algesten /// Buffer allows writing packets to an intermediate buffer, which can then be read form. 15*ffe74184SMartin Algesten /// This is verify similar to bytes.Buffer but avoids combining multiple writes into a single read. 16*ffe74184SMartin Algesten #[derive(Debug)] 17*ffe74184SMartin Algesten struct BufferInternal { 18*ffe74184SMartin Algesten data: Vec<u8>, 19*ffe74184SMartin Algesten head: usize, 20*ffe74184SMartin Algesten tail: usize, 21*ffe74184SMartin Algesten 22*ffe74184SMartin Algesten closed: bool, 23*ffe74184SMartin Algesten subs: bool, 24*ffe74184SMartin Algesten 25*ffe74184SMartin Algesten count: usize, 26*ffe74184SMartin Algesten limit_count: usize, 27*ffe74184SMartin Algesten limit_size: usize, 28*ffe74184SMartin Algesten } 29*ffe74184SMartin Algesten 30*ffe74184SMartin Algesten impl BufferInternal { 31*ffe74184SMartin Algesten /// available returns true if the buffer is large enough to fit a packet 32*ffe74184SMartin Algesten /// of the given size, taking overhead into account. available(&self, size: usize) -> bool33*ffe74184SMartin Algesten fn available(&self, size: usize) -> bool { 34*ffe74184SMartin Algesten let mut available = self.head as isize - self.tail as isize; 35*ffe74184SMartin Algesten if available <= 0 { 36*ffe74184SMartin Algesten available += self.data.len() as isize; 37*ffe74184SMartin Algesten } 38*ffe74184SMartin Algesten // we interpret head=tail as empty, so always keep a byte free 39*ffe74184SMartin Algesten size as isize + 2 < available 40*ffe74184SMartin Algesten } 41*ffe74184SMartin Algesten 42*ffe74184SMartin Algesten /// grow increases the size of the buffer. If it returns nil, then the 43*ffe74184SMartin Algesten /// buffer has been grown. It returns ErrFull if hits a limit. grow(&mut self) -> Result<()>44*ffe74184SMartin Algesten fn grow(&mut self) -> Result<()> { 45*ffe74184SMartin Algesten let mut newsize = if self.data.len() < CUTOFF_SIZE { 46*ffe74184SMartin Algesten 2 * self.data.len() 47*ffe74184SMartin Algesten } else { 48*ffe74184SMartin Algesten 5 * self.data.len() / 4 49*ffe74184SMartin Algesten }; 50*ffe74184SMartin Algesten 51*ffe74184SMartin Algesten if newsize < MIN_SIZE { 52*ffe74184SMartin Algesten newsize = MIN_SIZE 53*ffe74184SMartin Algesten } 54*ffe74184SMartin Algesten if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE { 55*ffe74184SMartin Algesten newsize = MAX_SIZE 56*ffe74184SMartin Algesten } 57*ffe74184SMartin Algesten 58*ffe74184SMartin Algesten // one byte slack 59*ffe74184SMartin Algesten if self.limit_size > 0 && newsize > self.limit_size + 1 { 60*ffe74184SMartin Algesten newsize = self.limit_size + 1 61*ffe74184SMartin Algesten } 62*ffe74184SMartin Algesten 63*ffe74184SMartin Algesten if newsize <= self.data.len() { 64*ffe74184SMartin Algesten return Err(Error::ErrBufferFull); 65*ffe74184SMartin Algesten } 66*ffe74184SMartin Algesten 67*ffe74184SMartin Algesten let mut newdata: Vec<u8> = vec![0; newsize]; 68*ffe74184SMartin Algesten 69*ffe74184SMartin Algesten let mut n; 70*ffe74184SMartin Algesten if self.head <= self.tail { 71*ffe74184SMartin Algesten // data was contiguous 72*ffe74184SMartin Algesten n = self.tail - self.head; 73*ffe74184SMartin Algesten newdata[..n].copy_from_slice(&self.data[self.head..self.tail]); 74*ffe74184SMartin Algesten } else { 75*ffe74184SMartin Algesten // data was discontiguous 76*ffe74184SMartin Algesten n = self.data.len() - self.head; 77*ffe74184SMartin Algesten newdata[..n].copy_from_slice(&self.data[self.head..]); 78*ffe74184SMartin Algesten newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]); 79*ffe74184SMartin Algesten n += self.tail; 80*ffe74184SMartin Algesten } 81*ffe74184SMartin Algesten self.head = 0; 82*ffe74184SMartin Algesten self.tail = n; 83*ffe74184SMartin Algesten self.data = newdata; 84*ffe74184SMartin Algesten 85*ffe74184SMartin Algesten Ok(()) 86*ffe74184SMartin Algesten } 87*ffe74184SMartin Algesten size(&self) -> usize88*ffe74184SMartin Algesten fn size(&self) -> usize { 89*ffe74184SMartin Algesten let mut size = self.tail as isize - self.head as isize; 90*ffe74184SMartin Algesten if size < 0 { 91*ffe74184SMartin Algesten size += self.data.len() as isize; 92*ffe74184SMartin Algesten } 93*ffe74184SMartin Algesten size as usize 94*ffe74184SMartin Algesten } 95*ffe74184SMartin Algesten } 96*ffe74184SMartin Algesten 97*ffe74184SMartin Algesten #[derive(Debug, Clone)] 98*ffe74184SMartin Algesten pub struct Buffer { 99*ffe74184SMartin Algesten buffer: Arc<Mutex<BufferInternal>>, 100*ffe74184SMartin Algesten notify: Arc<Notify>, 101*ffe74184SMartin Algesten } 102*ffe74184SMartin Algesten 103*ffe74184SMartin Algesten impl Buffer { new(limit_count: usize, limit_size: usize) -> Self104*ffe74184SMartin Algesten pub fn new(limit_count: usize, limit_size: usize) -> Self { 105*ffe74184SMartin Algesten Buffer { 106*ffe74184SMartin Algesten buffer: Arc::new(Mutex::new(BufferInternal { 107*ffe74184SMartin Algesten data: vec![], 108*ffe74184SMartin Algesten head: 0, 109*ffe74184SMartin Algesten tail: 0, 110*ffe74184SMartin Algesten 111*ffe74184SMartin Algesten closed: false, 112*ffe74184SMartin Algesten subs: false, 113*ffe74184SMartin Algesten 114*ffe74184SMartin Algesten count: 0, 115*ffe74184SMartin Algesten limit_count, 116*ffe74184SMartin Algesten limit_size, 117*ffe74184SMartin Algesten })), 118*ffe74184SMartin Algesten notify: Arc::new(Notify::new()), 119*ffe74184SMartin Algesten } 120*ffe74184SMartin Algesten } 121*ffe74184SMartin Algesten 122*ffe74184SMartin Algesten /// Write appends a copy of the packet data to the buffer. 123*ffe74184SMartin Algesten /// Returns ErrFull if the packet doesn't fit. 124*ffe74184SMartin Algesten /// Note that the packet size is limited to 65536 bytes since v0.11.0 125*ffe74184SMartin Algesten /// due to the internal data structure. write(&self, packet: &[u8]) -> Result<usize>126*ffe74184SMartin Algesten pub async fn write(&self, packet: &[u8]) -> Result<usize> { 127*ffe74184SMartin Algesten if packet.len() >= 0x10000 { 128*ffe74184SMartin Algesten return Err(Error::ErrPacketTooBig); 129*ffe74184SMartin Algesten } 130*ffe74184SMartin Algesten 131*ffe74184SMartin Algesten let mut b = self.buffer.lock().await; 132*ffe74184SMartin Algesten 133*ffe74184SMartin Algesten if b.closed { 134*ffe74184SMartin Algesten return Err(Error::ErrBufferClosed); 135*ffe74184SMartin Algesten } 136*ffe74184SMartin Algesten 137*ffe74184SMartin Algesten if (b.limit_count > 0 && b.count >= b.limit_count) 138*ffe74184SMartin Algesten || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size) 139*ffe74184SMartin Algesten { 140*ffe74184SMartin Algesten return Err(Error::ErrBufferFull); 141*ffe74184SMartin Algesten } 142*ffe74184SMartin Algesten 143*ffe74184SMartin Algesten // grow the buffer until the packet fits 144*ffe74184SMartin Algesten while !b.available(packet.len()) { 145*ffe74184SMartin Algesten b.grow()?; 146*ffe74184SMartin Algesten } 147*ffe74184SMartin Algesten 148*ffe74184SMartin Algesten // store the length of the packet 149*ffe74184SMartin Algesten let tail = b.tail; 150*ffe74184SMartin Algesten b.data[tail] = (packet.len() >> 8) as u8; 151*ffe74184SMartin Algesten b.tail += 1; 152*ffe74184SMartin Algesten if b.tail >= b.data.len() { 153*ffe74184SMartin Algesten b.tail = 0; 154*ffe74184SMartin Algesten } 155*ffe74184SMartin Algesten 156*ffe74184SMartin Algesten let tail = b.tail; 157*ffe74184SMartin Algesten b.data[tail] = packet.len() as u8; 158*ffe74184SMartin Algesten b.tail += 1; 159*ffe74184SMartin Algesten if b.tail >= b.data.len() { 160*ffe74184SMartin Algesten b.tail = 0; 161*ffe74184SMartin Algesten } 162*ffe74184SMartin Algesten 163*ffe74184SMartin Algesten // store the packet 164*ffe74184SMartin Algesten let end = std::cmp::min(b.data.len(), b.tail + packet.len()); 165*ffe74184SMartin Algesten let n = end - b.tail; 166*ffe74184SMartin Algesten let tail = b.tail; 167*ffe74184SMartin Algesten b.data[tail..end].copy_from_slice(&packet[..n]); 168*ffe74184SMartin Algesten b.tail += n; 169*ffe74184SMartin Algesten if b.tail >= b.data.len() { 170*ffe74184SMartin Algesten // we reached the end, wrap around 171*ffe74184SMartin Algesten let m = packet.len() - n; 172*ffe74184SMartin Algesten b.data[..m].copy_from_slice(&packet[n..]); 173*ffe74184SMartin Algesten b.tail = m; 174*ffe74184SMartin Algesten } 175*ffe74184SMartin Algesten b.count += 1; 176*ffe74184SMartin Algesten 177*ffe74184SMartin Algesten if b.subs { 178*ffe74184SMartin Algesten // we have other are waiting data 179*ffe74184SMartin Algesten self.notify.notify_one(); 180*ffe74184SMartin Algesten b.subs = false; 181*ffe74184SMartin Algesten } 182*ffe74184SMartin Algesten 183*ffe74184SMartin Algesten Ok(packet.len()) 184*ffe74184SMartin Algesten } 185*ffe74184SMartin Algesten 186*ffe74184SMartin Algesten // Read populates the given byte slice, returning the number of bytes read. 187*ffe74184SMartin Algesten // Blocks until data is available or the buffer is closed. 188*ffe74184SMartin Algesten // Returns io.ErrShortBuffer is the packet is too small to copy the Write. 189*ffe74184SMartin Algesten // Returns io.EOF if the buffer is closed. read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize>190*ffe74184SMartin Algesten pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> { 191*ffe74184SMartin Algesten loop { 192*ffe74184SMartin Algesten { 193*ffe74184SMartin Algesten // use {} to let LockGuard RAII 194*ffe74184SMartin Algesten let mut b = self.buffer.lock().await; 195*ffe74184SMartin Algesten 196*ffe74184SMartin Algesten if b.head != b.tail { 197*ffe74184SMartin Algesten // decode the packet size 198*ffe74184SMartin Algesten let n1 = b.data[b.head]; 199*ffe74184SMartin Algesten b.head += 1; 200*ffe74184SMartin Algesten if b.head >= b.data.len() { 201*ffe74184SMartin Algesten b.head = 0; 202*ffe74184SMartin Algesten } 203*ffe74184SMartin Algesten let n2 = b.data[b.head]; 204*ffe74184SMartin Algesten b.head += 1; 205*ffe74184SMartin Algesten if b.head >= b.data.len() { 206*ffe74184SMartin Algesten b.head = 0; 207*ffe74184SMartin Algesten } 208*ffe74184SMartin Algesten let count = ((n1 as usize) << 8) | n2 as usize; 209*ffe74184SMartin Algesten 210*ffe74184SMartin Algesten // determine the number of bytes we'll actually copy 211*ffe74184SMartin Algesten let mut copied = count; 212*ffe74184SMartin Algesten if copied > packet.len() { 213*ffe74184SMartin Algesten copied = packet.len(); 214*ffe74184SMartin Algesten } 215*ffe74184SMartin Algesten 216*ffe74184SMartin Algesten // copy the data 217*ffe74184SMartin Algesten if b.head + copied < b.data.len() { 218*ffe74184SMartin Algesten packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]); 219*ffe74184SMartin Algesten } else { 220*ffe74184SMartin Algesten let k = b.data.len() - b.head; 221*ffe74184SMartin Algesten packet[..k].copy_from_slice(&b.data[b.head..]); 222*ffe74184SMartin Algesten packet[k..copied].copy_from_slice(&b.data[..copied - k]); 223*ffe74184SMartin Algesten } 224*ffe74184SMartin Algesten 225*ffe74184SMartin Algesten // advance head, discarding any data that wasn't copied 226*ffe74184SMartin Algesten b.head += count; 227*ffe74184SMartin Algesten if b.head >= b.data.len() { 228*ffe74184SMartin Algesten b.head -= b.data.len(); 229*ffe74184SMartin Algesten } 230*ffe74184SMartin Algesten 231*ffe74184SMartin Algesten if b.head == b.tail { 232*ffe74184SMartin Algesten // the buffer is empty, reset to beginning 233*ffe74184SMartin Algesten // in order to improve cache locality. 234*ffe74184SMartin Algesten b.head = 0; 235*ffe74184SMartin Algesten b.tail = 0; 236*ffe74184SMartin Algesten } 237*ffe74184SMartin Algesten 238*ffe74184SMartin Algesten b.count -= 1; 239*ffe74184SMartin Algesten 240*ffe74184SMartin Algesten if copied < count { 241*ffe74184SMartin Algesten return Err(Error::ErrBufferShort); 242*ffe74184SMartin Algesten } 243*ffe74184SMartin Algesten return Ok(copied); 244*ffe74184SMartin Algesten } else { 245*ffe74184SMartin Algesten // Dont have data -> need wait 246*ffe74184SMartin Algesten b.subs = true; 247*ffe74184SMartin Algesten } 248*ffe74184SMartin Algesten 249*ffe74184SMartin Algesten if b.closed { 250*ffe74184SMartin Algesten return Err(Error::ErrBufferClosed); 251*ffe74184SMartin Algesten } 252*ffe74184SMartin Algesten } 253*ffe74184SMartin Algesten 254*ffe74184SMartin Algesten // Wait for signal. 255*ffe74184SMartin Algesten if let Some(d) = duration { 256*ffe74184SMartin Algesten if timeout(d, self.notify.notified()).await.is_err() { 257*ffe74184SMartin Algesten return Err(Error::ErrTimeout); 258*ffe74184SMartin Algesten } 259*ffe74184SMartin Algesten } else { 260*ffe74184SMartin Algesten self.notify.notified().await; 261*ffe74184SMartin Algesten } 262*ffe74184SMartin Algesten } 263*ffe74184SMartin Algesten } 264*ffe74184SMartin Algesten 265*ffe74184SMartin Algesten // Close will unblock any readers and prevent future writes. 266*ffe74184SMartin Algesten // Data in the buffer can still be read, returning io.EOF when fully depleted. close(&self)267*ffe74184SMartin Algesten pub async fn close(&self) { 268*ffe74184SMartin Algesten // note: We don't use defer so we can close the notify channel after unlocking. 269*ffe74184SMartin Algesten // This will unblock goroutines that can grab the lock immediately, instead of blocking again. 270*ffe74184SMartin Algesten let mut b = self.buffer.lock().await; 271*ffe74184SMartin Algesten 272*ffe74184SMartin Algesten if b.closed { 273*ffe74184SMartin Algesten return; 274*ffe74184SMartin Algesten } 275*ffe74184SMartin Algesten 276*ffe74184SMartin Algesten b.closed = true; 277*ffe74184SMartin Algesten self.notify.notify_waiters(); 278*ffe74184SMartin Algesten } 279*ffe74184SMartin Algesten is_closed(&self) -> bool280*ffe74184SMartin Algesten pub async fn is_closed(&self) -> bool { 281*ffe74184SMartin Algesten let b = self.buffer.lock().await; 282*ffe74184SMartin Algesten 283*ffe74184SMartin Algesten b.closed 284*ffe74184SMartin Algesten } 285*ffe74184SMartin Algesten 286*ffe74184SMartin Algesten // Count returns the number of packets in the buffer. count(&self) -> usize287*ffe74184SMartin Algesten pub async fn count(&self) -> usize { 288*ffe74184SMartin Algesten let b = self.buffer.lock().await; 289*ffe74184SMartin Algesten 290*ffe74184SMartin Algesten b.count 291*ffe74184SMartin Algesten } 292*ffe74184SMartin Algesten 293*ffe74184SMartin Algesten // set_limit_count controls the maximum number of packets that can be buffered. 294*ffe74184SMartin Algesten // Causes Write to return ErrFull when this limit is reached. 295*ffe74184SMartin Algesten // A zero value will disable this limit. set_limit_count(&self, limit: usize)296*ffe74184SMartin Algesten pub async fn set_limit_count(&self, limit: usize) { 297*ffe74184SMartin Algesten let mut b = self.buffer.lock().await; 298*ffe74184SMartin Algesten 299*ffe74184SMartin Algesten b.limit_count = limit 300*ffe74184SMartin Algesten } 301*ffe74184SMartin Algesten 302*ffe74184SMartin Algesten // Size returns the total byte size of packets in the buffer. size(&self) -> usize303*ffe74184SMartin Algesten pub async fn size(&self) -> usize { 304*ffe74184SMartin Algesten let b = self.buffer.lock().await; 305*ffe74184SMartin Algesten 306*ffe74184SMartin Algesten b.size() 307*ffe74184SMartin Algesten } 308*ffe74184SMartin Algesten 309*ffe74184SMartin Algesten // set_limit_size controls the maximum number of bytes that can be buffered. 310*ffe74184SMartin Algesten // Causes Write to return ErrFull when this limit is reached. 311*ffe74184SMartin Algesten // A zero value means 4MB since v0.11.0. 312*ffe74184SMartin Algesten // 313*ffe74184SMartin Algesten // User can set packetioSizeHardlimit build tag to enable 4MB hardlimit. 314*ffe74184SMartin Algesten // When packetioSizeHardlimit build tag is set, set_limit_size exceeding 315*ffe74184SMartin Algesten // the hardlimit will be silently discarded. set_limit_size(&self, limit: usize)316*ffe74184SMartin Algesten pub async fn set_limit_size(&self, limit: usize) { 317*ffe74184SMartin Algesten let mut b = self.buffer.lock().await; 318*ffe74184SMartin Algesten 319*ffe74184SMartin Algesten b.limit_size = limit 320*ffe74184SMartin Algesten } 321*ffe74184SMartin Algesten } 322