1 #[cfg(test)] 2 mod buffer_test; 3 4 use crate::error::{Error, Result}; 5 6 use std::sync::Arc; 7 use tokio::sync::{Mutex, Notify}; 8 use tokio::time::{timeout, Duration}; 9 10 const MIN_SIZE: usize = 2048; 11 const CUTOFF_SIZE: usize = 128 * 1024; 12 const MAX_SIZE: usize = 4 * 1024 * 1024; 13 14 /// Buffer allows writing packets to an intermediate buffer, which can then be read form. 15 /// This is verify similar to bytes.Buffer but avoids combining multiple writes into a single read. 16 #[derive(Debug)] 17 struct BufferInternal { 18 data: Vec<u8>, 19 head: usize, 20 tail: usize, 21 22 closed: bool, 23 subs: bool, 24 25 count: usize, 26 limit_count: usize, 27 limit_size: usize, 28 } 29 30 impl BufferInternal { 31 /// available returns true if the buffer is large enough to fit a packet 32 /// of the given size, taking overhead into account. available(&self, size: usize) -> bool33 fn available(&self, size: usize) -> bool { 34 let mut available = self.head as isize - self.tail as isize; 35 if available <= 0 { 36 available += self.data.len() as isize; 37 } 38 // we interpret head=tail as empty, so always keep a byte free 39 size as isize + 2 < available 40 } 41 42 /// grow increases the size of the buffer. If it returns nil, then the 43 /// buffer has been grown. It returns ErrFull if hits a limit. grow(&mut self) -> Result<()>44 fn grow(&mut self) -> Result<()> { 45 let mut newsize = if self.data.len() < CUTOFF_SIZE { 46 2 * self.data.len() 47 } else { 48 5 * self.data.len() / 4 49 }; 50 51 if newsize < MIN_SIZE { 52 newsize = MIN_SIZE 53 } 54 if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE { 55 newsize = MAX_SIZE 56 } 57 58 // one byte slack 59 if self.limit_size > 0 && newsize > self.limit_size + 1 { 60 newsize = self.limit_size + 1 61 } 62 63 if newsize <= self.data.len() { 64 return Err(Error::ErrBufferFull); 65 } 66 67 let mut newdata: Vec<u8> = vec![0; newsize]; 68 69 let mut n; 70 if self.head <= self.tail { 71 // data was contiguous 72 n = self.tail - self.head; 73 newdata[..n].copy_from_slice(&self.data[self.head..self.tail]); 74 } else { 75 // data was discontiguous 76 n = self.data.len() - self.head; 77 newdata[..n].copy_from_slice(&self.data[self.head..]); 78 newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]); 79 n += self.tail; 80 } 81 self.head = 0; 82 self.tail = n; 83 self.data = newdata; 84 85 Ok(()) 86 } 87 size(&self) -> usize88 fn size(&self) -> usize { 89 let mut size = self.tail as isize - self.head as isize; 90 if size < 0 { 91 size += self.data.len() as isize; 92 } 93 size as usize 94 } 95 } 96 97 #[derive(Debug, Clone)] 98 pub struct Buffer { 99 buffer: Arc<Mutex<BufferInternal>>, 100 notify: Arc<Notify>, 101 } 102 103 impl Buffer { new(limit_count: usize, limit_size: usize) -> Self104 pub fn new(limit_count: usize, limit_size: usize) -> Self { 105 Buffer { 106 buffer: Arc::new(Mutex::new(BufferInternal { 107 data: vec![], 108 head: 0, 109 tail: 0, 110 111 closed: false, 112 subs: false, 113 114 count: 0, 115 limit_count, 116 limit_size, 117 })), 118 notify: Arc::new(Notify::new()), 119 } 120 } 121 122 /// Write appends a copy of the packet data to the buffer. 123 /// Returns ErrFull if the packet doesn't fit. 124 /// Note that the packet size is limited to 65536 bytes since v0.11.0 125 /// due to the internal data structure. write(&self, packet: &[u8]) -> Result<usize>126 pub async fn write(&self, packet: &[u8]) -> Result<usize> { 127 if packet.len() >= 0x10000 { 128 return Err(Error::ErrPacketTooBig); 129 } 130 131 let mut b = self.buffer.lock().await; 132 133 if b.closed { 134 return Err(Error::ErrBufferClosed); 135 } 136 137 if (b.limit_count > 0 && b.count >= b.limit_count) 138 || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size) 139 { 140 return Err(Error::ErrBufferFull); 141 } 142 143 // grow the buffer until the packet fits 144 while !b.available(packet.len()) { 145 b.grow()?; 146 } 147 148 // store the length of the packet 149 let tail = b.tail; 150 b.data[tail] = (packet.len() >> 8) as u8; 151 b.tail += 1; 152 if b.tail >= b.data.len() { 153 b.tail = 0; 154 } 155 156 let tail = b.tail; 157 b.data[tail] = packet.len() as u8; 158 b.tail += 1; 159 if b.tail >= b.data.len() { 160 b.tail = 0; 161 } 162 163 // store the packet 164 let end = std::cmp::min(b.data.len(), b.tail + packet.len()); 165 let n = end - b.tail; 166 let tail = b.tail; 167 b.data[tail..end].copy_from_slice(&packet[..n]); 168 b.tail += n; 169 if b.tail >= b.data.len() { 170 // we reached the end, wrap around 171 let m = packet.len() - n; 172 b.data[..m].copy_from_slice(&packet[n..]); 173 b.tail = m; 174 } 175 b.count += 1; 176 177 if b.subs { 178 // we have other are waiting data 179 self.notify.notify_one(); 180 b.subs = false; 181 } 182 183 Ok(packet.len()) 184 } 185 186 // Read populates the given byte slice, returning the number of bytes read. 187 // Blocks until data is available or the buffer is closed. 188 // Returns io.ErrShortBuffer is the packet is too small to copy the Write. 189 // Returns io.EOF if the buffer is closed. read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize>190 pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> { 191 loop { 192 { 193 // use {} to let LockGuard RAII 194 let mut b = self.buffer.lock().await; 195 196 if b.head != b.tail { 197 // decode the packet size 198 let n1 = b.data[b.head]; 199 b.head += 1; 200 if b.head >= b.data.len() { 201 b.head = 0; 202 } 203 let n2 = b.data[b.head]; 204 b.head += 1; 205 if b.head >= b.data.len() { 206 b.head = 0; 207 } 208 let count = ((n1 as usize) << 8) | n2 as usize; 209 210 // determine the number of bytes we'll actually copy 211 let mut copied = count; 212 if copied > packet.len() { 213 copied = packet.len(); 214 } 215 216 // copy the data 217 if b.head + copied < b.data.len() { 218 packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]); 219 } else { 220 let k = b.data.len() - b.head; 221 packet[..k].copy_from_slice(&b.data[b.head..]); 222 packet[k..copied].copy_from_slice(&b.data[..copied - k]); 223 } 224 225 // advance head, discarding any data that wasn't copied 226 b.head += count; 227 if b.head >= b.data.len() { 228 b.head -= b.data.len(); 229 } 230 231 if b.head == b.tail { 232 // the buffer is empty, reset to beginning 233 // in order to improve cache locality. 234 b.head = 0; 235 b.tail = 0; 236 } 237 238 b.count -= 1; 239 240 if copied < count { 241 return Err(Error::ErrBufferShort); 242 } 243 return Ok(copied); 244 } else { 245 // Dont have data -> need wait 246 b.subs = true; 247 } 248 249 if b.closed { 250 return Err(Error::ErrBufferClosed); 251 } 252 } 253 254 // Wait for signal. 255 if let Some(d) = duration { 256 if timeout(d, self.notify.notified()).await.is_err() { 257 return Err(Error::ErrTimeout); 258 } 259 } else { 260 self.notify.notified().await; 261 } 262 } 263 } 264 265 // Close will unblock any readers and prevent future writes. 266 // Data in the buffer can still be read, returning io.EOF when fully depleted. close(&self)267 pub async fn close(&self) { 268 // note: We don't use defer so we can close the notify channel after unlocking. 269 // This will unblock goroutines that can grab the lock immediately, instead of blocking again. 270 let mut b = self.buffer.lock().await; 271 272 if b.closed { 273 return; 274 } 275 276 b.closed = true; 277 self.notify.notify_waiters(); 278 } 279 is_closed(&self) -> bool280 pub async fn is_closed(&self) -> bool { 281 let b = self.buffer.lock().await; 282 283 b.closed 284 } 285 286 // Count returns the number of packets in the buffer. count(&self) -> usize287 pub async fn count(&self) -> usize { 288 let b = self.buffer.lock().await; 289 290 b.count 291 } 292 293 // set_limit_count controls the maximum number of packets that can be buffered. 294 // Causes Write to return ErrFull when this limit is reached. 295 // A zero value will disable this limit. set_limit_count(&self, limit: usize)296 pub async fn set_limit_count(&self, limit: usize) { 297 let mut b = self.buffer.lock().await; 298 299 b.limit_count = limit 300 } 301 302 // Size returns the total byte size of packets in the buffer. size(&self) -> usize303 pub async fn size(&self) -> usize { 304 let b = self.buffer.lock().await; 305 306 b.size() 307 } 308 309 // set_limit_size controls the maximum number of bytes that can be buffered. 310 // Causes Write to return ErrFull when this limit is reached. 311 // A zero value means 4MB since v0.11.0. 312 // 313 // User can set packetioSizeHardlimit build tag to enable 4MB hardlimit. 314 // When packetioSizeHardlimit build tag is set, set_limit_size exceeding 315 // the hardlimit will be silently discarded. set_limit_size(&self, limit: usize)316 pub async fn set_limit_size(&self, limit: usize) { 317 let mut b = self.buffer.lock().await; 318 319 b.limit_size = limit 320 } 321 } 322