xref: /webrtc/util/src/buffer/mod.rs (revision ffe74184)
1 #[cfg(test)]
2 mod buffer_test;
3 
4 use crate::error::{Error, Result};
5 
6 use std::sync::Arc;
7 use tokio::sync::{Mutex, Notify};
8 use tokio::time::{timeout, Duration};
9 
/// Smallest capacity the ring buffer is ever grown to (bytes).
const MIN_SIZE: usize = 2048;
/// Below this capacity the buffer doubles when it grows; at or above it,
/// it grows by 25% per step (see `BufferInternal::grow`).
const CUTOFF_SIZE: usize = 128 * 1024;
/// Default growth cap (4 MiB) applied when no explicit `limit_size` is set.
const MAX_SIZE: usize = 4 * 1024 * 1024;
13 
/// Buffer allows writing packets to an intermediate buffer, which can then be read from.
/// This is very similar to bytes.Buffer but avoids combining multiple writes into a single read.
#[derive(Debug)]
struct BufferInternal {
    /// Ring storage; each packet is stored as a 2-byte big-endian length
    /// header followed by its payload.
    data: Vec<u8>,
    /// Index of the next byte to read; head == tail means empty.
    head: usize,
    /// Index of the next byte to write.
    tail: usize,

    /// Set once close() is called; write() then fails and read() returns
    /// ErrBufferClosed after the remaining packets are drained.
    closed: bool,
    /// True while a reader is parked waiting for data; tells write() to
    /// issue a wakeup.
    subs: bool,

    /// Number of packets currently stored.
    count: usize,
    /// Max packet count before write() fails with ErrBufferFull (0 = no limit).
    limit_count: usize,
    /// Max buffered bytes (headers included) before write() fails
    /// (0 = fall back to the MAX_SIZE growth cap).
    limit_size: usize,
}
29 
30 impl BufferInternal {
31     /// available returns true if the buffer is large enough to fit a packet
32     /// of the given size, taking overhead into account.
available(&self, size: usize) -> bool33     fn available(&self, size: usize) -> bool {
34         let mut available = self.head as isize - self.tail as isize;
35         if available <= 0 {
36             available += self.data.len() as isize;
37         }
38         // we interpret head=tail as empty, so always keep a byte free
39         size as isize + 2 < available
40     }
41 
42     /// grow increases the size of the buffer.  If it returns nil, then the
43     /// buffer has been grown.  It returns ErrFull if hits a limit.
grow(&mut self) -> Result<()>44     fn grow(&mut self) -> Result<()> {
45         let mut newsize = if self.data.len() < CUTOFF_SIZE {
46             2 * self.data.len()
47         } else {
48             5 * self.data.len() / 4
49         };
50 
51         if newsize < MIN_SIZE {
52             newsize = MIN_SIZE
53         }
54         if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE {
55             newsize = MAX_SIZE
56         }
57 
58         // one byte slack
59         if self.limit_size > 0 && newsize > self.limit_size + 1 {
60             newsize = self.limit_size + 1
61         }
62 
63         if newsize <= self.data.len() {
64             return Err(Error::ErrBufferFull);
65         }
66 
67         let mut newdata: Vec<u8> = vec![0; newsize];
68 
69         let mut n;
70         if self.head <= self.tail {
71             // data was contiguous
72             n = self.tail - self.head;
73             newdata[..n].copy_from_slice(&self.data[self.head..self.tail]);
74         } else {
75             // data was discontiguous
76             n = self.data.len() - self.head;
77             newdata[..n].copy_from_slice(&self.data[self.head..]);
78             newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]);
79             n += self.tail;
80         }
81         self.head = 0;
82         self.tail = n;
83         self.data = newdata;
84 
85         Ok(())
86     }
87 
size(&self) -> usize88     fn size(&self) -> usize {
89         let mut size = self.tail as isize - self.head as isize;
90         if size < 0 {
91             size += self.data.len() as isize;
92         }
93         size as usize
94     }
95 }
96 
/// Async packet buffer: write() appends whole packets, read() removes them
/// one at a time. Clones are cheap and share the same underlying ring and
/// wakeup primitive via Arc.
#[derive(Debug, Clone)]
pub struct Buffer {
    // Ring-buffer state, guarded by an async mutex.
    buffer: Arc<Mutex<BufferInternal>>,
    // Used by write()/close() to wake readers parked in read().
    notify: Arc<Notify>,
}
102 
103 impl Buffer {
new(limit_count: usize, limit_size: usize) -> Self104     pub fn new(limit_count: usize, limit_size: usize) -> Self {
105         Buffer {
106             buffer: Arc::new(Mutex::new(BufferInternal {
107                 data: vec![],
108                 head: 0,
109                 tail: 0,
110 
111                 closed: false,
112                 subs: false,
113 
114                 count: 0,
115                 limit_count,
116                 limit_size,
117             })),
118             notify: Arc::new(Notify::new()),
119         }
120     }
121 
    /// Write appends a copy of the packet data to the buffer.
    /// Returns ErrBufferFull if the packet doesn't fit within the
    /// configured limits, and ErrBufferClosed after close().
    ///
    /// Note that the packet size is limited to 65535 bytes since v0.11.0,
    /// because the length is stored in a 2-byte header in the ring.
    pub async fn write(&self, packet: &[u8]) -> Result<usize> {
        // Anything >= 2^16 cannot be framed by the 2-byte length header.
        if packet.len() >= 0x10000 {
            return Err(Error::ErrPacketTooBig);
        }

        let mut b = self.buffer.lock().await;

        if b.closed {
            return Err(Error::ErrBufferClosed);
        }

        // Enforce the configured packet-count and byte-size limits
        // (+2 accounts for this packet's length header).
        if (b.limit_count > 0 && b.count >= b.limit_count)
            || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size)
        {
            return Err(Error::ErrBufferFull);
        }

        // grow the buffer until the packet fits
        while !b.available(packet.len()) {
            b.grow()?;
        }

        // store the length of the packet as a 2-byte big-endian header,
        // wrapping tail around the ring after each byte
        let tail = b.tail;
        b.data[tail] = (packet.len() >> 8) as u8;
        b.tail += 1;
        if b.tail >= b.data.len() {
            b.tail = 0;
        }

        let tail = b.tail;
        b.data[tail] = packet.len() as u8;
        b.tail += 1;
        if b.tail >= b.data.len() {
            b.tail = 0;
        }

        // store the packet payload, split in two when it crosses the end
        let end = std::cmp::min(b.data.len(), b.tail + packet.len());
        let n = end - b.tail;
        let tail = b.tail;
        b.data[tail..end].copy_from_slice(&packet[..n]);
        b.tail += n;
        if b.tail >= b.data.len() {
            // we reached the end, wrap around
            let m = packet.len() - n;
            b.data[..m].copy_from_slice(&packet[n..]);
            b.tail = m;
        }
        b.count += 1;

        if b.subs {
            // a reader is parked in read(); wake exactly one
            self.notify.notify_one();
            b.subs = false;
        }

        Ok(packet.len())
    }
185 
    /// Read populates the given byte slice, returning the number of bytes read.
    /// Blocks until data is available or the buffer is closed.
    ///
    /// Returns ErrBufferShort if `packet` is too small for the next stored
    /// packet; the packet is still consumed and the excess discarded.
    /// Returns ErrBufferClosed once the buffer is closed and fully drained.
    /// Returns ErrTimeout if `duration` is Some and elapses with no data.
    pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> {
        loop {
            {
                // use {} to let LockGuard RAII
                let mut b = self.buffer.lock().await;

                if b.head != b.tail {
                    // decode the 2-byte big-endian packet size, wrapping
                    // head around the ring after each byte
                    let n1 = b.data[b.head];
                    b.head += 1;
                    if b.head >= b.data.len() {
                        b.head = 0;
                    }
                    let n2 = b.data[b.head];
                    b.head += 1;
                    if b.head >= b.data.len() {
                        b.head = 0;
                    }
                    let count = ((n1 as usize) << 8) | n2 as usize;

                    // determine the number of bytes we'll actually copy
                    let mut copied = count;
                    if copied > packet.len() {
                        copied = packet.len();
                    }

                    // copy the data (two pieces when it wraps the ring end)
                    if b.head + copied < b.data.len() {
                        packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]);
                    } else {
                        let k = b.data.len() - b.head;
                        packet[..k].copy_from_slice(&b.data[b.head..]);
                        packet[k..copied].copy_from_slice(&b.data[..copied - k]);
                    }

                    // advance head, discarding any data that wasn't copied
                    b.head += count;
                    if b.head >= b.data.len() {
                        b.head -= b.data.len();
                    }

                    if b.head == b.tail {
                        // the buffer is empty, reset to beginning
                        // in order to improve cache locality.
                        b.head = 0;
                        b.tail = 0;
                    }

                    b.count -= 1;

                    if copied < count {
                        return Err(Error::ErrBufferShort);
                    }
                    return Ok(copied);
                } else {
                    // No data yet: mark that a reader is waiting so the
                    // next write() issues a wakeup, then block below.
                    b.subs = true;
                }

                if b.closed {
                    return Err(Error::ErrBufferClosed);
                }
            }

            // Wait for a writer's signal (the lock was released above).
            // A permit stored by notify_one() before we reach notified()
            // is not lost: the future completes immediately in that case.
            if let Some(d) = duration {
                if timeout(d, self.notify.notified()).await.is_err() {
                    return Err(Error::ErrTimeout);
                }
            } else {
                self.notify.notified().await;
            }
        }
    }
264 
265     // Close will unblock any readers and prevent future writes.
266     // Data in the buffer can still be read, returning io.EOF when fully depleted.
close(&self)267     pub async fn close(&self) {
268         // note: We don't use defer so we can close the notify channel after unlocking.
269         // This will unblock goroutines that can grab the lock immediately, instead of blocking again.
270         let mut b = self.buffer.lock().await;
271 
272         if b.closed {
273             return;
274         }
275 
276         b.closed = true;
277         self.notify.notify_waiters();
278     }
279 
is_closed(&self) -> bool280     pub async fn is_closed(&self) -> bool {
281         let b = self.buffer.lock().await;
282 
283         b.closed
284     }
285 
286     // Count returns the number of packets in the buffer.
count(&self) -> usize287     pub async fn count(&self) -> usize {
288         let b = self.buffer.lock().await;
289 
290         b.count
291     }
292 
293     // set_limit_count controls the maximum number of packets that can be buffered.
294     // Causes Write to return ErrFull when this limit is reached.
295     // A zero value will disable this limit.
set_limit_count(&self, limit: usize)296     pub async fn set_limit_count(&self, limit: usize) {
297         let mut b = self.buffer.lock().await;
298 
299         b.limit_count = limit
300     }
301 
302     // Size returns the total byte size of packets in the buffer.
size(&self) -> usize303     pub async fn size(&self) -> usize {
304         let b = self.buffer.lock().await;
305 
306         b.size()
307     }
308 
309     // set_limit_size controls the maximum number of bytes that can be buffered.
310     // Causes Write to return ErrFull when this limit is reached.
311     // A zero value means 4MB since v0.11.0.
312     //
313     // User can set packetioSizeHardlimit build tag to enable 4MB hardlimit.
314     // When packetioSizeHardlimit build tag is set, set_limit_size exceeding
315     // the hardlimit will be silently discarded.
set_limit_size(&self, limit: usize)316     pub async fn set_limit_size(&self, limit: usize) {
317         let mut b = self.buffer.lock().await;
318 
319         b.limit_size = limit
320     }
321 }
322