xref: /webrtc/util/src/buffer/mod.rs (revision ffe74184)
1*ffe74184SMartin Algesten #[cfg(test)]
2*ffe74184SMartin Algesten mod buffer_test;
3*ffe74184SMartin Algesten 
4*ffe74184SMartin Algesten use crate::error::{Error, Result};
5*ffe74184SMartin Algesten 
6*ffe74184SMartin Algesten use std::sync::Arc;
7*ffe74184SMartin Algesten use tokio::sync::{Mutex, Notify};
8*ffe74184SMartin Algesten use tokio::time::{timeout, Duration};
9*ffe74184SMartin Algesten 
/// Smallest capacity, in bytes, that the ring storage is ever grown to (2 KiB).
const MIN_SIZE: usize = 2 << 10;
/// Growth-policy threshold (128 KiB): below it the storage doubles on each
/// growth step, at or above it the storage grows by a quarter.
const CUTOFF_SIZE: usize = 128 << 10;
/// Capacity ceiling (4 MiB) applied while no explicit `limit_size` is set.
const MAX_SIZE: usize = 4 << 20;
13*ffe74184SMartin Algesten 
/// Ring-buffer state shared by all handles of a `Buffer`.
///
/// Packets are stored back to back in `data`. Each one is a 2-byte big-endian
/// length header followed by the payload, and may wrap around the end of
/// `data`. `head == tail` means "empty", so one byte of `data` is always left
/// unused to tell a full ring apart from an empty one.
#[derive(Debug)]
struct BufferInternal {
    /// Ring storage; allocated lazily and enlarged by `grow` when a packet
    /// does not fit.
    data: Vec<u8>,
    /// Index of the first header byte of the oldest unread packet.
    head: usize,
    /// Index at which the next packet's header will be written.
    tail: usize,

    /// Set once the buffer has been closed; writes are then rejected.
    closed: bool,
    /// Set when a reader found the buffer empty and is waiting to be woken by
    /// the next write.
    subs: bool,

    /// Number of packets currently stored.
    count: usize,
    /// Maximum number of stored packets; `0` disables the limit.
    limit_count: usize,
    /// Maximum number of stored bytes, length headers included; `0` disables
    /// this explicit limit.
    limit_size: usize,
}
29*ffe74184SMartin Algesten 
30*ffe74184SMartin Algesten impl BufferInternal {
31*ffe74184SMartin Algesten     /// available returns true if the buffer is large enough to fit a packet
32*ffe74184SMartin Algesten     /// of the given size, taking overhead into account.
available(&self, size: usize) -> bool33*ffe74184SMartin Algesten     fn available(&self, size: usize) -> bool {
34*ffe74184SMartin Algesten         let mut available = self.head as isize - self.tail as isize;
35*ffe74184SMartin Algesten         if available <= 0 {
36*ffe74184SMartin Algesten             available += self.data.len() as isize;
37*ffe74184SMartin Algesten         }
38*ffe74184SMartin Algesten         // we interpret head=tail as empty, so always keep a byte free
39*ffe74184SMartin Algesten         size as isize + 2 < available
40*ffe74184SMartin Algesten     }
41*ffe74184SMartin Algesten 
42*ffe74184SMartin Algesten     /// grow increases the size of the buffer.  If it returns nil, then the
43*ffe74184SMartin Algesten     /// buffer has been grown.  It returns ErrFull if hits a limit.
grow(&mut self) -> Result<()>44*ffe74184SMartin Algesten     fn grow(&mut self) -> Result<()> {
45*ffe74184SMartin Algesten         let mut newsize = if self.data.len() < CUTOFF_SIZE {
46*ffe74184SMartin Algesten             2 * self.data.len()
47*ffe74184SMartin Algesten         } else {
48*ffe74184SMartin Algesten             5 * self.data.len() / 4
49*ffe74184SMartin Algesten         };
50*ffe74184SMartin Algesten 
51*ffe74184SMartin Algesten         if newsize < MIN_SIZE {
52*ffe74184SMartin Algesten             newsize = MIN_SIZE
53*ffe74184SMartin Algesten         }
54*ffe74184SMartin Algesten         if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE {
55*ffe74184SMartin Algesten             newsize = MAX_SIZE
56*ffe74184SMartin Algesten         }
57*ffe74184SMartin Algesten 
58*ffe74184SMartin Algesten         // one byte slack
59*ffe74184SMartin Algesten         if self.limit_size > 0 && newsize > self.limit_size + 1 {
60*ffe74184SMartin Algesten             newsize = self.limit_size + 1
61*ffe74184SMartin Algesten         }
62*ffe74184SMartin Algesten 
63*ffe74184SMartin Algesten         if newsize <= self.data.len() {
64*ffe74184SMartin Algesten             return Err(Error::ErrBufferFull);
65*ffe74184SMartin Algesten         }
66*ffe74184SMartin Algesten 
67*ffe74184SMartin Algesten         let mut newdata: Vec<u8> = vec![0; newsize];
68*ffe74184SMartin Algesten 
69*ffe74184SMartin Algesten         let mut n;
70*ffe74184SMartin Algesten         if self.head <= self.tail {
71*ffe74184SMartin Algesten             // data was contiguous
72*ffe74184SMartin Algesten             n = self.tail - self.head;
73*ffe74184SMartin Algesten             newdata[..n].copy_from_slice(&self.data[self.head..self.tail]);
74*ffe74184SMartin Algesten         } else {
75*ffe74184SMartin Algesten             // data was discontiguous
76*ffe74184SMartin Algesten             n = self.data.len() - self.head;
77*ffe74184SMartin Algesten             newdata[..n].copy_from_slice(&self.data[self.head..]);
78*ffe74184SMartin Algesten             newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]);
79*ffe74184SMartin Algesten             n += self.tail;
80*ffe74184SMartin Algesten         }
81*ffe74184SMartin Algesten         self.head = 0;
82*ffe74184SMartin Algesten         self.tail = n;
83*ffe74184SMartin Algesten         self.data = newdata;
84*ffe74184SMartin Algesten 
85*ffe74184SMartin Algesten         Ok(())
86*ffe74184SMartin Algesten     }
87*ffe74184SMartin Algesten 
size(&self) -> usize88*ffe74184SMartin Algesten     fn size(&self) -> usize {
89*ffe74184SMartin Algesten         let mut size = self.tail as isize - self.head as isize;
90*ffe74184SMartin Algesten         if size < 0 {
91*ffe74184SMartin Algesten             size += self.data.len() as isize;
92*ffe74184SMartin Algesten         }
93*ffe74184SMartin Algesten         size as usize
94*ffe74184SMartin Algesten     }
95*ffe74184SMartin Algesten }
96*ffe74184SMartin Algesten 
97*ffe74184SMartin Algesten #[derive(Debug, Clone)]
98*ffe74184SMartin Algesten pub struct Buffer {
99*ffe74184SMartin Algesten     buffer: Arc<Mutex<BufferInternal>>,
100*ffe74184SMartin Algesten     notify: Arc<Notify>,
101*ffe74184SMartin Algesten }
102*ffe74184SMartin Algesten 
103*ffe74184SMartin Algesten impl Buffer {
new(limit_count: usize, limit_size: usize) -> Self104*ffe74184SMartin Algesten     pub fn new(limit_count: usize, limit_size: usize) -> Self {
105*ffe74184SMartin Algesten         Buffer {
106*ffe74184SMartin Algesten             buffer: Arc::new(Mutex::new(BufferInternal {
107*ffe74184SMartin Algesten                 data: vec![],
108*ffe74184SMartin Algesten                 head: 0,
109*ffe74184SMartin Algesten                 tail: 0,
110*ffe74184SMartin Algesten 
111*ffe74184SMartin Algesten                 closed: false,
112*ffe74184SMartin Algesten                 subs: false,
113*ffe74184SMartin Algesten 
114*ffe74184SMartin Algesten                 count: 0,
115*ffe74184SMartin Algesten                 limit_count,
116*ffe74184SMartin Algesten                 limit_size,
117*ffe74184SMartin Algesten             })),
118*ffe74184SMartin Algesten             notify: Arc::new(Notify::new()),
119*ffe74184SMartin Algesten         }
120*ffe74184SMartin Algesten     }
121*ffe74184SMartin Algesten 
122*ffe74184SMartin Algesten     /// Write appends a copy of the packet data to the buffer.
123*ffe74184SMartin Algesten     /// Returns ErrFull if the packet doesn't fit.
124*ffe74184SMartin Algesten     /// Note that the packet size is limited to 65536 bytes since v0.11.0
125*ffe74184SMartin Algesten     /// due to the internal data structure.
write(&self, packet: &[u8]) -> Result<usize>126*ffe74184SMartin Algesten     pub async fn write(&self, packet: &[u8]) -> Result<usize> {
127*ffe74184SMartin Algesten         if packet.len() >= 0x10000 {
128*ffe74184SMartin Algesten             return Err(Error::ErrPacketTooBig);
129*ffe74184SMartin Algesten         }
130*ffe74184SMartin Algesten 
131*ffe74184SMartin Algesten         let mut b = self.buffer.lock().await;
132*ffe74184SMartin Algesten 
133*ffe74184SMartin Algesten         if b.closed {
134*ffe74184SMartin Algesten             return Err(Error::ErrBufferClosed);
135*ffe74184SMartin Algesten         }
136*ffe74184SMartin Algesten 
137*ffe74184SMartin Algesten         if (b.limit_count > 0 && b.count >= b.limit_count)
138*ffe74184SMartin Algesten             || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size)
139*ffe74184SMartin Algesten         {
140*ffe74184SMartin Algesten             return Err(Error::ErrBufferFull);
141*ffe74184SMartin Algesten         }
142*ffe74184SMartin Algesten 
143*ffe74184SMartin Algesten         // grow the buffer until the packet fits
144*ffe74184SMartin Algesten         while !b.available(packet.len()) {
145*ffe74184SMartin Algesten             b.grow()?;
146*ffe74184SMartin Algesten         }
147*ffe74184SMartin Algesten 
148*ffe74184SMartin Algesten         // store the length of the packet
149*ffe74184SMartin Algesten         let tail = b.tail;
150*ffe74184SMartin Algesten         b.data[tail] = (packet.len() >> 8) as u8;
151*ffe74184SMartin Algesten         b.tail += 1;
152*ffe74184SMartin Algesten         if b.tail >= b.data.len() {
153*ffe74184SMartin Algesten             b.tail = 0;
154*ffe74184SMartin Algesten         }
155*ffe74184SMartin Algesten 
156*ffe74184SMartin Algesten         let tail = b.tail;
157*ffe74184SMartin Algesten         b.data[tail] = packet.len() as u8;
158*ffe74184SMartin Algesten         b.tail += 1;
159*ffe74184SMartin Algesten         if b.tail >= b.data.len() {
160*ffe74184SMartin Algesten             b.tail = 0;
161*ffe74184SMartin Algesten         }
162*ffe74184SMartin Algesten 
163*ffe74184SMartin Algesten         // store the packet
164*ffe74184SMartin Algesten         let end = std::cmp::min(b.data.len(), b.tail + packet.len());
165*ffe74184SMartin Algesten         let n = end - b.tail;
166*ffe74184SMartin Algesten         let tail = b.tail;
167*ffe74184SMartin Algesten         b.data[tail..end].copy_from_slice(&packet[..n]);
168*ffe74184SMartin Algesten         b.tail += n;
169*ffe74184SMartin Algesten         if b.tail >= b.data.len() {
170*ffe74184SMartin Algesten             // we reached the end, wrap around
171*ffe74184SMartin Algesten             let m = packet.len() - n;
172*ffe74184SMartin Algesten             b.data[..m].copy_from_slice(&packet[n..]);
173*ffe74184SMartin Algesten             b.tail = m;
174*ffe74184SMartin Algesten         }
175*ffe74184SMartin Algesten         b.count += 1;
176*ffe74184SMartin Algesten 
177*ffe74184SMartin Algesten         if b.subs {
178*ffe74184SMartin Algesten             // we have other are waiting data
179*ffe74184SMartin Algesten             self.notify.notify_one();
180*ffe74184SMartin Algesten             b.subs = false;
181*ffe74184SMartin Algesten         }
182*ffe74184SMartin Algesten 
183*ffe74184SMartin Algesten         Ok(packet.len())
184*ffe74184SMartin Algesten     }
185*ffe74184SMartin Algesten 
186*ffe74184SMartin Algesten     // Read populates the given byte slice, returning the number of bytes read.
187*ffe74184SMartin Algesten     // Blocks until data is available or the buffer is closed.
188*ffe74184SMartin Algesten     // Returns io.ErrShortBuffer is the packet is too small to copy the Write.
189*ffe74184SMartin Algesten     // Returns io.EOF if the buffer is closed.
read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize>190*ffe74184SMartin Algesten     pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> {
191*ffe74184SMartin Algesten         loop {
192*ffe74184SMartin Algesten             {
193*ffe74184SMartin Algesten                 // use {} to let LockGuard RAII
194*ffe74184SMartin Algesten                 let mut b = self.buffer.lock().await;
195*ffe74184SMartin Algesten 
196*ffe74184SMartin Algesten                 if b.head != b.tail {
197*ffe74184SMartin Algesten                     // decode the packet size
198*ffe74184SMartin Algesten                     let n1 = b.data[b.head];
199*ffe74184SMartin Algesten                     b.head += 1;
200*ffe74184SMartin Algesten                     if b.head >= b.data.len() {
201*ffe74184SMartin Algesten                         b.head = 0;
202*ffe74184SMartin Algesten                     }
203*ffe74184SMartin Algesten                     let n2 = b.data[b.head];
204*ffe74184SMartin Algesten                     b.head += 1;
205*ffe74184SMartin Algesten                     if b.head >= b.data.len() {
206*ffe74184SMartin Algesten                         b.head = 0;
207*ffe74184SMartin Algesten                     }
208*ffe74184SMartin Algesten                     let count = ((n1 as usize) << 8) | n2 as usize;
209*ffe74184SMartin Algesten 
210*ffe74184SMartin Algesten                     // determine the number of bytes we'll actually copy
211*ffe74184SMartin Algesten                     let mut copied = count;
212*ffe74184SMartin Algesten                     if copied > packet.len() {
213*ffe74184SMartin Algesten                         copied = packet.len();
214*ffe74184SMartin Algesten                     }
215*ffe74184SMartin Algesten 
216*ffe74184SMartin Algesten                     // copy the data
217*ffe74184SMartin Algesten                     if b.head + copied < b.data.len() {
218*ffe74184SMartin Algesten                         packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]);
219*ffe74184SMartin Algesten                     } else {
220*ffe74184SMartin Algesten                         let k = b.data.len() - b.head;
221*ffe74184SMartin Algesten                         packet[..k].copy_from_slice(&b.data[b.head..]);
222*ffe74184SMartin Algesten                         packet[k..copied].copy_from_slice(&b.data[..copied - k]);
223*ffe74184SMartin Algesten                     }
224*ffe74184SMartin Algesten 
225*ffe74184SMartin Algesten                     // advance head, discarding any data that wasn't copied
226*ffe74184SMartin Algesten                     b.head += count;
227*ffe74184SMartin Algesten                     if b.head >= b.data.len() {
228*ffe74184SMartin Algesten                         b.head -= b.data.len();
229*ffe74184SMartin Algesten                     }
230*ffe74184SMartin Algesten 
231*ffe74184SMartin Algesten                     if b.head == b.tail {
232*ffe74184SMartin Algesten                         // the buffer is empty, reset to beginning
233*ffe74184SMartin Algesten                         // in order to improve cache locality.
234*ffe74184SMartin Algesten                         b.head = 0;
235*ffe74184SMartin Algesten                         b.tail = 0;
236*ffe74184SMartin Algesten                     }
237*ffe74184SMartin Algesten 
238*ffe74184SMartin Algesten                     b.count -= 1;
239*ffe74184SMartin Algesten 
240*ffe74184SMartin Algesten                     if copied < count {
241*ffe74184SMartin Algesten                         return Err(Error::ErrBufferShort);
242*ffe74184SMartin Algesten                     }
243*ffe74184SMartin Algesten                     return Ok(copied);
244*ffe74184SMartin Algesten                 } else {
245*ffe74184SMartin Algesten                     // Dont have data -> need wait
246*ffe74184SMartin Algesten                     b.subs = true;
247*ffe74184SMartin Algesten                 }
248*ffe74184SMartin Algesten 
249*ffe74184SMartin Algesten                 if b.closed {
250*ffe74184SMartin Algesten                     return Err(Error::ErrBufferClosed);
251*ffe74184SMartin Algesten                 }
252*ffe74184SMartin Algesten             }
253*ffe74184SMartin Algesten 
254*ffe74184SMartin Algesten             // Wait for signal.
255*ffe74184SMartin Algesten             if let Some(d) = duration {
256*ffe74184SMartin Algesten                 if timeout(d, self.notify.notified()).await.is_err() {
257*ffe74184SMartin Algesten                     return Err(Error::ErrTimeout);
258*ffe74184SMartin Algesten                 }
259*ffe74184SMartin Algesten             } else {
260*ffe74184SMartin Algesten                 self.notify.notified().await;
261*ffe74184SMartin Algesten             }
262*ffe74184SMartin Algesten         }
263*ffe74184SMartin Algesten     }
264*ffe74184SMartin Algesten 
265*ffe74184SMartin Algesten     // Close will unblock any readers and prevent future writes.
266*ffe74184SMartin Algesten     // Data in the buffer can still be read, returning io.EOF when fully depleted.
close(&self)267*ffe74184SMartin Algesten     pub async fn close(&self) {
268*ffe74184SMartin Algesten         // note: We don't use defer so we can close the notify channel after unlocking.
269*ffe74184SMartin Algesten         // This will unblock goroutines that can grab the lock immediately, instead of blocking again.
270*ffe74184SMartin Algesten         let mut b = self.buffer.lock().await;
271*ffe74184SMartin Algesten 
272*ffe74184SMartin Algesten         if b.closed {
273*ffe74184SMartin Algesten             return;
274*ffe74184SMartin Algesten         }
275*ffe74184SMartin Algesten 
276*ffe74184SMartin Algesten         b.closed = true;
277*ffe74184SMartin Algesten         self.notify.notify_waiters();
278*ffe74184SMartin Algesten     }
279*ffe74184SMartin Algesten 
is_closed(&self) -> bool280*ffe74184SMartin Algesten     pub async fn is_closed(&self) -> bool {
281*ffe74184SMartin Algesten         let b = self.buffer.lock().await;
282*ffe74184SMartin Algesten 
283*ffe74184SMartin Algesten         b.closed
284*ffe74184SMartin Algesten     }
285*ffe74184SMartin Algesten 
286*ffe74184SMartin Algesten     // Count returns the number of packets in the buffer.
count(&self) -> usize287*ffe74184SMartin Algesten     pub async fn count(&self) -> usize {
288*ffe74184SMartin Algesten         let b = self.buffer.lock().await;
289*ffe74184SMartin Algesten 
290*ffe74184SMartin Algesten         b.count
291*ffe74184SMartin Algesten     }
292*ffe74184SMartin Algesten 
293*ffe74184SMartin Algesten     // set_limit_count controls the maximum number of packets that can be buffered.
294*ffe74184SMartin Algesten     // Causes Write to return ErrFull when this limit is reached.
295*ffe74184SMartin Algesten     // A zero value will disable this limit.
set_limit_count(&self, limit: usize)296*ffe74184SMartin Algesten     pub async fn set_limit_count(&self, limit: usize) {
297*ffe74184SMartin Algesten         let mut b = self.buffer.lock().await;
298*ffe74184SMartin Algesten 
299*ffe74184SMartin Algesten         b.limit_count = limit
300*ffe74184SMartin Algesten     }
301*ffe74184SMartin Algesten 
302*ffe74184SMartin Algesten     // Size returns the total byte size of packets in the buffer.
size(&self) -> usize303*ffe74184SMartin Algesten     pub async fn size(&self) -> usize {
304*ffe74184SMartin Algesten         let b = self.buffer.lock().await;
305*ffe74184SMartin Algesten 
306*ffe74184SMartin Algesten         b.size()
307*ffe74184SMartin Algesten     }
308*ffe74184SMartin Algesten 
309*ffe74184SMartin Algesten     // set_limit_size controls the maximum number of bytes that can be buffered.
310*ffe74184SMartin Algesten     // Causes Write to return ErrFull when this limit is reached.
311*ffe74184SMartin Algesten     // A zero value means 4MB since v0.11.0.
312*ffe74184SMartin Algesten     //
313*ffe74184SMartin Algesten     // User can set packetioSizeHardlimit build tag to enable 4MB hardlimit.
314*ffe74184SMartin Algesten     // When packetioSizeHardlimit build tag is set, set_limit_size exceeding
315*ffe74184SMartin Algesten     // the hardlimit will be silently discarded.
set_limit_size(&self, limit: usize)316*ffe74184SMartin Algesten     pub async fn set_limit_size(&self, limit: usize) {
317*ffe74184SMartin Algesten         let mut b = self.buffer.lock().await;
318*ffe74184SMartin Algesten 
319*ffe74184SMartin Algesten         b.limit_size = limit
320*ffe74184SMartin Algesten     }
321*ffe74184SMartin Algesten }
322