xref: /webrtc/sctp/src/stream/mod.rs (revision 97921129)
1 #[cfg(test)]
2 mod stream_test;
3 
4 use crate::association::AssociationState;
5 use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
6 use crate::error::{Error, Result};
7 use crate::queue::pending_queue::PendingQueue;
8 use crate::queue::reassembly_queue::ReassemblyQueue;
9 
10 use arc_swap::ArcSwapOption;
11 use bytes::Bytes;
12 use std::{
13     fmt,
14     future::Future,
15     io,
16     net::Shutdown,
17     pin::Pin,
18     sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering},
19     sync::Arc,
20     task::{Context, Poll},
21 };
22 use tokio::{
23     io::{AsyncRead, AsyncWrite, ReadBuf},
24     sync::{mpsc, Mutex, Notify},
25 };
26 
/// Partial-reliability policy applied to the messages of a stream.
///
/// The explicit discriminants let the value round-trip through a `u8`
/// (see the `From<u8>` impl below).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub enum ReliabilityType {
    /// Fully reliable transmission (the default).
    #[default]
    Reliable = 0,
    /// Partial reliability bounded by the number of retransmissions.
    Rexmit = 1,
    /// Partial reliability bounded by how long retransmission may continue.
    Timed = 2,
}

impl fmt::Display for ReliabilityType {
    /// Writes the bare variant name, e.g. `Timed`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Reliable => "Reliable",
            Self::Rexmit => "Rexmit",
            Self::Timed => "Timed",
        };
        f.write_str(label)
    }
}

impl From<u8> for ReliabilityType {
    /// Decodes a stored discriminant; any unrecognised value falls back to `Reliable`.
    fn from(v: u8) -> ReliabilityType {
        if v == Self::Rexmit as u8 {
            Self::Rexmit
        } else if v == Self::Timed as u8 {
            Self::Timed
        } else {
            Self::Reliable
        }
    }
}
59 
/// Callback invoked when a stream's buffered amount drops to or below its low threshold.
///
/// Each call returns a boxed future that the stream awaits before returning.
pub type OnBufferedAmountLowFn =
    Box<dyn FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> + Send + Sync>;
62 
63 // TODO: benchmark performance between multiple Atomic+Mutex vs one Mutex<StreamInternal>
64 
65 /// Stream represents an SCTP stream
66 #[derive(Default)]
67 pub struct Stream {
68     pub(crate) max_payload_size: u32,
69     pub(crate) max_message_size: Arc<AtomicU32>, // clone from association
70     pub(crate) state: Arc<AtomicU8>,             // clone from association
71     pub(crate) awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
72     pub(crate) pending_queue: Arc<PendingQueue>,
73 
74     pub(crate) stream_identifier: u16,
75     pub(crate) default_payload_type: AtomicU32, //PayloadProtocolIdentifier,
76     pub(crate) reassembly_queue: Mutex<ReassemblyQueue>,
77     pub(crate) sequence_number: AtomicU16,
78     pub(crate) read_notifier: Notify,
79     pub(crate) read_shutdown: AtomicBool,
80     pub(crate) write_shutdown: AtomicBool,
81     pub(crate) unordered: AtomicBool,
82     pub(crate) reliability_type: AtomicU8, //ReliabilityType,
83     pub(crate) reliability_value: AtomicU32,
84     pub(crate) buffered_amount: AtomicUsize,
85     pub(crate) buffered_amount_low: AtomicUsize,
86     pub(crate) on_buffered_amount_low: ArcSwapOption<Mutex<OnBufferedAmountLowFn>>,
87     pub(crate) name: String,
88 }
89 
90 impl fmt::Debug for Stream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result91     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92         f.debug_struct("Stream")
93             .field("max_payload_size", &self.max_payload_size)
94             .field("max_message_size", &self.max_message_size)
95             .field("state", &self.state)
96             .field("awake_write_loop_ch", &self.awake_write_loop_ch)
97             .field("stream_identifier", &self.stream_identifier)
98             .field("default_payload_type", &self.default_payload_type)
99             .field("reassembly_queue", &self.reassembly_queue)
100             .field("sequence_number", &self.sequence_number)
101             .field("read_shutdown", &self.read_shutdown)
102             .field("write_shutdown", &self.write_shutdown)
103             .field("unordered", &self.unordered)
104             .field("reliability_type", &self.reliability_type)
105             .field("reliability_value", &self.reliability_value)
106             .field("buffered_amount", &self.buffered_amount)
107             .field("buffered_amount_low", &self.buffered_amount_low)
108             .field("name", &self.name)
109             .finish()
110     }
111 }
112 
113 impl Stream {
new( name: String, stream_identifier: u16, max_payload_size: u32, max_message_size: Arc<AtomicU32>, state: Arc<AtomicU8>, awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>, pending_queue: Arc<PendingQueue>, ) -> Self114     pub(crate) fn new(
115         name: String,
116         stream_identifier: u16,
117         max_payload_size: u32,
118         max_message_size: Arc<AtomicU32>,
119         state: Arc<AtomicU8>,
120         awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>,
121         pending_queue: Arc<PendingQueue>,
122     ) -> Self {
123         Stream {
124             max_payload_size,
125             max_message_size,
126             state,
127             awake_write_loop_ch,
128             pending_queue,
129 
130             stream_identifier,
131             default_payload_type: AtomicU32::new(0), //PayloadProtocolIdentifier::Unknown,
132             reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)),
133             sequence_number: AtomicU16::new(0),
134             read_notifier: Notify::new(),
135             read_shutdown: AtomicBool::new(false),
136             write_shutdown: AtomicBool::new(false),
137             unordered: AtomicBool::new(false),
138             reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable,
139             reliability_value: AtomicU32::new(0),
140             buffered_amount: AtomicUsize::new(0),
141             buffered_amount_low: AtomicUsize::new(0),
142             on_buffered_amount_low: ArcSwapOption::empty(),
143             name,
144         }
145     }
146 
147     /// stream_identifier returns the Stream identifier associated to the stream.
stream_identifier(&self) -> u16148     pub fn stream_identifier(&self) -> u16 {
149         self.stream_identifier
150     }
151 
152     /// set_default_payload_type sets the default payload type used by write.
set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier)153     pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) {
154         self.default_payload_type
155             .store(default_payload_type as u32, Ordering::SeqCst);
156     }
157 
158     /// set_reliability_params sets reliability parameters for this stream.
set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32)159     pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) {
160         log::debug!(
161             "[{}] reliability params: ordered={} type={} value={}",
162             self.name,
163             !unordered,
164             rel_type,
165             rel_val
166         );
167         self.unordered.store(unordered, Ordering::SeqCst);
168         self.reliability_type
169             .store(rel_type as u8, Ordering::SeqCst);
170         self.reliability_value.store(rel_val, Ordering::SeqCst);
171     }
172 
173     /// Reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
174     ///
175     /// Returns `Error::ErrShortBuffer` if `p` is too short.
176     /// Returns `0` if the reading half of this stream is shutdown or it (the stream) was reset.
read(&self, p: &mut [u8]) -> Result<usize>177     pub async fn read(&self, p: &mut [u8]) -> Result<usize> {
178         let (n, _) = self.read_sctp(p).await?;
179         Ok(n)
180     }
181 
182     /// Reads a packet of len(p) bytes and returns the associated Payload Protocol Identifier.
183     ///
184     /// Returns `Error::ErrShortBuffer` if `p` is too short.
185     /// Returns `(0, PayloadProtocolIdentifier::Unknown)` if the reading half of this stream is shutdown or it (the stream) was reset.
read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)>186     pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
187         loop {
188             if self.read_shutdown.load(Ordering::SeqCst) {
189                 return Ok((0, PayloadProtocolIdentifier::Unknown));
190             }
191 
192             let result = {
193                 let mut reassembly_queue = self.reassembly_queue.lock().await;
194                 reassembly_queue.read(p)
195             };
196 
197             match result {
198                 Ok(_) | Err(Error::ErrShortBuffer) => return result,
199                 Err(_) => {
200                     // wait for the next chunk to become available
201                     self.read_notifier.notified().await;
202                 }
203             }
204         }
205     }
206 
handle_data(&self, pd: ChunkPayloadData)207     pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) {
208         let readable = {
209             let mut reassembly_queue = self.reassembly_queue.lock().await;
210             if reassembly_queue.push(pd) {
211                 let readable = reassembly_queue.is_readable();
212                 log::debug!("[{}] reassemblyQueue readable={}", self.name, readable);
213                 readable
214             } else {
215                 false
216             }
217         };
218 
219         if readable {
220             log::debug!("[{}] readNotifier.signal()", self.name);
221             self.read_notifier.notify_one();
222             log::debug!("[{}] readNotifier.signal() done", self.name);
223         }
224     }
225 
handle_forward_tsn_for_ordered(&self, ssn: u16)226     pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) {
227         if self.unordered.load(Ordering::SeqCst) {
228             return; // unordered chunks are handled by handleForwardUnordered method
229         }
230 
231         // Remove all chunks older than or equal to the new TSN from
232         // the reassembly_queue.
233         let readable = {
234             let mut reassembly_queue = self.reassembly_queue.lock().await;
235             reassembly_queue.forward_tsn_for_ordered(ssn);
236             reassembly_queue.is_readable()
237         };
238 
239         // Notify the reader asynchronously if there's a data chunk to read.
240         if readable {
241             self.read_notifier.notify_one();
242         }
243     }
244 
handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32)245     pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) {
246         if !self.unordered.load(Ordering::SeqCst) {
247             return; // ordered chunks are handled by handleForwardTSNOrdered method
248         }
249 
250         // Remove all chunks older than or equal to the new TSN from
251         // the reassembly_queue.
252         let readable = {
253             let mut reassembly_queue = self.reassembly_queue.lock().await;
254             reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn);
255             reassembly_queue.is_readable()
256         };
257 
258         // Notify the reader asynchronously if there's a data chunk to read.
259         if readable {
260             self.read_notifier.notify_one();
261         }
262     }
263 
264     /// Writes `p` to the DTLS connection with the default Payload Protocol Identifier.
265     ///
266     /// Returns an error if the write half of this stream is shutdown or `p` is too large.
write(&self, p: &Bytes) -> Result<usize>267     pub async fn write(&self, p: &Bytes) -> Result<usize> {
268         self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into())
269             .await
270     }
271 
272     /// Writes `p` to the DTLS connection with the given Payload Protocol Identifier.
273     ///
274     /// Returns an error if the write half of this stream is shutdown or `p` is too large.
write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize>275     pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
276         let chunks = self.prepare_write(p, ppi)?;
277         self.send_payload_data(chunks).await?;
278 
279         Ok(p.len())
280     }
281 
282     /// common stuff for write and try_write
prepare_write( &self, p: &Bytes, ppi: PayloadProtocolIdentifier, ) -> Result<Vec<ChunkPayloadData>>283     fn prepare_write(
284         &self,
285         p: &Bytes,
286         ppi: PayloadProtocolIdentifier,
287     ) -> Result<Vec<ChunkPayloadData>> {
288         if self.write_shutdown.load(Ordering::SeqCst) {
289             return Err(Error::ErrStreamClosed);
290         }
291 
292         if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize {
293             return Err(Error::ErrOutboundPacketTooLarge);
294         }
295 
296         let state: AssociationState = self.state.load(Ordering::SeqCst).into();
297         match state {
298             AssociationState::ShutdownSent
299             | AssociationState::ShutdownAckSent
300             | AssociationState::ShutdownPending
301             | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
302             _ => {}
303         };
304 
305         Ok(self.packetize(p, ppi))
306     }
307 
packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData>308     fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> {
309         let mut i = 0;
310         let mut remaining = raw.len();
311 
312         // From draft-ietf-rtcweb-data-protocol-09, section 6:
313         //   All Data Channel Establishment Protocol messages MUST be sent using
314         //   ordered delivery and reliable transmission.
315         let unordered =
316             ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst);
317 
318         let mut chunks = vec![];
319 
320         let head_abandoned = Arc::new(AtomicBool::new(false));
321         let head_all_inflight = Arc::new(AtomicBool::new(false));
322         while remaining != 0 {
323             let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); //self.association.max_payload_size
324 
325             // Copy the userdata since we'll have to store it until acked
326             // and the caller may re-use the buffer in the mean time
327             let user_data = raw.slice(i..i + fragment_size);
328 
329             let chunk = ChunkPayloadData {
330                 stream_identifier: self.stream_identifier,
331                 user_data,
332                 unordered,
333                 beginning_fragment: i == 0,
334                 ending_fragment: remaining - fragment_size == 0,
335                 immediate_sack: false,
336                 payload_type: ppi,
337                 stream_sequence_number: self.sequence_number.load(Ordering::SeqCst),
338                 abandoned: head_abandoned.clone(), // all fragmented chunks use the same abandoned
339                 all_inflight: head_all_inflight.clone(), // all fragmented chunks use the same all_inflight
340                 ..Default::default()
341             };
342 
343             chunks.push(chunk);
344 
345             remaining -= fragment_size;
346             i += fragment_size;
347         }
348 
349         // RFC 4960 Sec 6.6
350         // Note: When transmitting ordered and unordered data, an endpoint does
351         // not increment its Stream Sequence Number when transmitting a DATA
352         // chunk with U flag set to 1.
353         if !unordered {
354             self.sequence_number.fetch_add(1, Ordering::SeqCst);
355         }
356 
357         let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst);
358         log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len());
359 
360         chunks
361     }
362 
363     /// Closes both read and write halves of this stream.
364     ///
365     /// Use [`Stream::shutdown`] instead.
366     #[deprecated]
close(&self) -> Result<()>367     pub async fn close(&self) -> Result<()> {
368         self.shutdown(Shutdown::Both).await
369     }
370 
371     /// Shuts down the read, write, or both halves of this stream.
372     ///
373     /// This function will cause all pending and future I/O on the specified portions to return
374     /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
375     ///
376     /// Resets the stream when both halves of this stream are shutdown.
shutdown(&self, how: Shutdown) -> Result<()>377     pub async fn shutdown(&self, how: Shutdown) -> Result<()> {
378         if self.read_shutdown.load(Ordering::SeqCst) && self.write_shutdown.load(Ordering::SeqCst) {
379             return Ok(());
380         }
381 
382         if how == Shutdown::Write || how == Shutdown::Both {
383             self.write_shutdown.store(true, Ordering::SeqCst);
384         }
385 
386         if (how == Shutdown::Read || how == Shutdown::Both)
387             && !self.read_shutdown.swap(true, Ordering::SeqCst)
388         {
389             self.read_notifier.notify_waiters();
390         }
391 
392         if how == Shutdown::Both
393             || (self.read_shutdown.load(Ordering::SeqCst)
394                 && self.write_shutdown.load(Ordering::SeqCst))
395         {
396             // Reset the stream
397             // https://tools.ietf.org/html/rfc6525
398             self.send_reset_request(self.stream_identifier).await?;
399         }
400 
401         Ok(())
402     }
403 
404     /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream.
buffered_amount(&self) -> usize405     pub fn buffered_amount(&self) -> usize {
406         self.buffered_amount.load(Ordering::SeqCst)
407     }
408 
409     /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is
410     /// considered "low." Defaults to 0.
buffered_amount_low_threshold(&self) -> usize411     pub fn buffered_amount_low_threshold(&self) -> usize {
412         self.buffered_amount_low.load(Ordering::SeqCst)
413     }
414 
415     /// set_buffered_amount_low_threshold is used to update the threshold.
416     /// See buffered_amount_low_threshold().
set_buffered_amount_low_threshold(&self, th: usize)417     pub fn set_buffered_amount_low_threshold(&self, th: usize) {
418         self.buffered_amount_low.store(th, Ordering::SeqCst);
419     }
420 
421     /// on_buffered_amount_low sets the callback handler which would be called when the number of
422     /// bytes of outgoing data buffered is lower than the threshold.
on_buffered_amount_low(&self, f: OnBufferedAmountLowFn)423     pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
424         self.on_buffered_amount_low
425             .store(Some(Arc::new(Mutex::new(f))));
426     }
427 
428     /// This method is called by association's read_loop (go-)routine to notify this stream
429     /// of the specified amount of outgoing data has been delivered to the peer.
on_buffer_released(&self, n_bytes_released: i64)430     pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) {
431         if n_bytes_released <= 0 {
432             return;
433         }
434 
435         let from_amount = self.buffered_amount.load(Ordering::SeqCst);
436         let new_amount = if from_amount < n_bytes_released as usize {
437             self.buffered_amount.store(0, Ordering::SeqCst);
438             log::error!(
439                 "[{}] released buffer size {} should be <= {}",
440                 self.name,
441                 n_bytes_released,
442                 0,
443             );
444             0
445         } else {
446             self.buffered_amount
447                 .fetch_sub(n_bytes_released as usize, Ordering::SeqCst);
448 
449             from_amount - n_bytes_released as usize
450         };
451 
452         let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst);
453 
454         log::trace!(
455             "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}",
456             self.name,
457             new_amount,
458             from_amount,
459             buffered_amount_low,
460         );
461 
462         if from_amount > buffered_amount_low && new_amount <= buffered_amount_low {
463             if let Some(handler) = &*self.on_buffered_amount_low.load() {
464                 let mut f = handler.lock().await;
465                 f().await;
466             }
467         }
468     }
469 
470     /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
471     /// be read (once chunk is complete).
get_num_bytes_in_reassembly_queue(&self) -> usize472     pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
473         // No lock is required as it reads the size with atomic load function.
474         let reassembly_queue = self.reassembly_queue.lock().await;
475         reassembly_queue.get_num_bytes()
476     }
477 
478     /// get_state atomically returns the state of the Association.
get_state(&self) -> AssociationState479     fn get_state(&self) -> AssociationState {
480         self.state.load(Ordering::SeqCst).into()
481     }
482 
awake_write_loop(&self)483     fn awake_write_loop(&self) {
484         //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name);
485         if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch {
486             let _ = awake_write_loop_ch.try_send(());
487         }
488     }
489 
send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()>490     async fn send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()> {
491         let state = self.get_state();
492         if state != AssociationState::Established {
493             return Err(Error::ErrPayloadDataStateNotExist);
494         }
495 
496         // NOTE: append is used here instead of push in order to prevent chunks interlacing.
497         self.pending_queue.append(chunks).await;
498 
499         self.awake_write_loop();
500         Ok(())
501     }
502 
send_reset_request(&self, stream_identifier: u16) -> Result<()>503     async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> {
504         let state = self.get_state();
505         if state != AssociationState::Established {
506             return Err(Error::ErrResetPacketInStateNotExist);
507         }
508 
509         // Create DATA chunk which only contains valid stream identifier with
510         // nil userData and use it as a EOS from the stream.
511         let c = ChunkPayloadData {
512             stream_identifier,
513             beginning_fragment: true,
514             ending_fragment: true,
515             user_data: Bytes::new(),
516             ..Default::default()
517         };
518 
519         self.pending_queue.push(c).await;
520 
521         self.awake_write_loop();
522         Ok(())
523     }
524 }
525 
/// Size, in bytes, of the scratch buffer [`PollStream`] allocates for each read (8 KiB).
const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
528 
529 /// State of the read `Future` in [`PollStream`].
530 enum ReadFut {
531     /// Nothing in progress.
532     Idle,
533     /// Reading data from the underlying stream.
534     Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
535     /// Finished reading, but there's unread data in the temporary buffer.
536     RemainingData(Vec<u8>),
537 }
538 
539 enum ShutdownFut {
540     /// Nothing in progress.
541     Idle,
542     /// Reading data from the underlying stream.
543     ShuttingDown(Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>>),
544     /// Shutdown future has run
545     Done,
546     Errored(crate::error::Error),
547 }
548 
549 impl ReadFut {
550     /// Gets a mutable reference to the future stored inside `Reading(future)`.
551     ///
552     /// # Panics
553     ///
554     /// Panics if `ReadFut` variant is not `Reading`.
get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>555     fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
556         match self {
557             ReadFut::Reading(ref mut fut) => fut,
558             _ => panic!("expected ReadFut to be Reading"),
559         }
560     }
561 }
562 
563 impl ShutdownFut {
564     /// Gets a mutable reference to the future stored inside `ShuttingDown(future)`.
565     ///
566     /// # Panics
567     ///
568     /// Panics if `ShutdownFut` variant is not `ShuttingDown`.
get_shutting_down_mut( &mut self, ) -> &mut Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>>569     fn get_shutting_down_mut(
570         &mut self,
571     ) -> &mut Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>> {
572         match self {
573             ShutdownFut::ShuttingDown(ref mut fut) => fut,
574             _ => panic!("expected ShutdownFut to be ShuttingDown"),
575         }
576     }
577 }
578 
579 /// A wrapper around around [`Stream`], which implements [`AsyncRead`] and
580 /// [`AsyncWrite`].
581 ///
582 /// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
583 /// additional overhead.
584 pub struct PollStream {
585     stream: Arc<Stream>,
586 
587     read_fut: ReadFut,
588     write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>>>>>,
589     shutdown_fut: ShutdownFut,
590 
591     read_buf_cap: usize,
592 }
593 
594 impl PollStream {
595     /// Constructs a new `PollStream`.
596     ///
597     /// # Examples
598     ///
599     /// ```
600     /// use webrtc_sctp::stream::{Stream, PollStream};
601     /// use std::sync::Arc;
602     ///
603     /// let stream = Arc::new(Stream::default());
604     /// let poll_stream = PollStream::new(stream);
605     /// ```
new(stream: Arc<Stream>) -> Self606     pub fn new(stream: Arc<Stream>) -> Self {
607         Self {
608             stream,
609             read_fut: ReadFut::Idle,
610             write_fut: None,
611             shutdown_fut: ShutdownFut::Idle,
612             read_buf_cap: DEFAULT_READ_BUF_SIZE,
613         }
614     }
615 
616     /// Get back the inner stream.
617     #[must_use]
into_inner(self) -> Arc<Stream>618     pub fn into_inner(self) -> Arc<Stream> {
619         self.stream
620     }
621 
622     /// Obtain a clone of the inner stream.
623     #[must_use]
clone_inner(&self) -> Arc<Stream>624     pub fn clone_inner(&self) -> Arc<Stream> {
625         self.stream.clone()
626     }
627 
628     /// stream_identifier returns the Stream identifier associated to the stream.
stream_identifier(&self) -> u16629     pub fn stream_identifier(&self) -> u16 {
630         self.stream.stream_identifier
631     }
632 
633     /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream.
buffered_amount(&self) -> usize634     pub fn buffered_amount(&self) -> usize {
635         self.stream.buffered_amount.load(Ordering::SeqCst)
636     }
637 
638     /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is
639     /// considered "low." Defaults to 0.
buffered_amount_low_threshold(&self) -> usize640     pub fn buffered_amount_low_threshold(&self) -> usize {
641         self.stream.buffered_amount_low.load(Ordering::SeqCst)
642     }
643 
644     /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
645     /// be read (once chunk is complete).
get_num_bytes_in_reassembly_queue(&self) -> usize646     pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
647         // No lock is required as it reads the size with atomic load function.
648         let reassembly_queue = self.stream.reassembly_queue.lock().await;
649         reassembly_queue.get_num_bytes()
650     }
651 
652     /// Set the capacity of the temporary read buffer (default: 8192).
set_read_buf_capacity(&mut self, capacity: usize)653     pub fn set_read_buf_capacity(&mut self, capacity: usize) {
654         self.read_buf_cap = capacity
655     }
656 }
657 
658 impl AsyncRead for PollStream {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>659     fn poll_read(
660         mut self: Pin<&mut Self>,
661         cx: &mut Context<'_>,
662         buf: &mut ReadBuf<'_>,
663     ) -> Poll<io::Result<()>> {
664         if buf.remaining() == 0 {
665             return Poll::Ready(Ok(()));
666         }
667 
668         let fut = match self.read_fut {
669             ReadFut::Idle => {
670                 // read into a temporary buffer because `buf` has an unonymous lifetime, which can
671                 // be shorter than the lifetime of `read_fut`.
672                 let stream = self.stream.clone();
673                 let mut temp_buf = vec![0; self.read_buf_cap];
674                 self.read_fut = ReadFut::Reading(Box::pin(async move {
675                     stream.read(temp_buf.as_mut_slice()).await.map(|n| {
676                         temp_buf.truncate(n);
677                         temp_buf
678                     })
679                 }));
680                 self.read_fut.get_reading_mut()
681             }
682             ReadFut::Reading(ref mut fut) => fut,
683             ReadFut::RemainingData(ref mut data) => {
684                 let remaining = buf.remaining();
685                 let len = std::cmp::min(data.len(), remaining);
686                 buf.put_slice(&data[..len]);
687                 if data.len() > remaining {
688                     // ReadFut remains to be RemainingData
689                     data.drain(0..len);
690                 } else {
691                     self.read_fut = ReadFut::Idle;
692                 }
693                 return Poll::Ready(Ok(()));
694             }
695         };
696 
697         loop {
698             match fut.as_mut().poll(cx) {
699                 Poll::Pending => return Poll::Pending,
700                 // retry immediately upon empty data or incomplete chunks
701                 // since there's no way to setup a waker.
702                 Poll::Ready(Err(Error::ErrTryAgain)) => {}
703                 // EOF has been reached => don't touch buf and just return Ok
704                 Poll::Ready(Err(Error::ErrEof)) => {
705                     self.read_fut = ReadFut::Idle;
706                     return Poll::Ready(Ok(()));
707                 }
708                 Poll::Ready(Err(e)) => {
709                     self.read_fut = ReadFut::Idle;
710                     return Poll::Ready(Err(e.into()));
711                 }
712                 Poll::Ready(Ok(mut temp_buf)) => {
713                     let remaining = buf.remaining();
714                     let len = std::cmp::min(temp_buf.len(), remaining);
715                     buf.put_slice(&temp_buf[..len]);
716                     if temp_buf.len() > remaining {
717                         temp_buf.drain(0..len);
718                         self.read_fut = ReadFut::RemainingData(temp_buf);
719                     } else {
720                         self.read_fut = ReadFut::Idle;
721                     }
722                     return Poll::Ready(Ok(()));
723                 }
724             }
725         }
726     }
727 }
728 
729 impl AsyncWrite for PollStream {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>730     fn poll_write(
731         mut self: Pin<&mut Self>,
732         cx: &mut Context<'_>,
733         buf: &[u8],
734     ) -> Poll<io::Result<usize>> {
735         if buf.is_empty() {
736             return Poll::Ready(Ok(0));
737         }
738 
739         if let Some(fut) = self.write_fut.as_mut() {
740             match fut.as_mut().poll(cx) {
741                 Poll::Pending => Poll::Pending,
742                 Poll::Ready(Err(e)) => {
743                     let stream = self.stream.clone();
744                     let bytes = Bytes::copy_from_slice(buf);
745                     self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await }));
746                     Poll::Ready(Err(e.into()))
747                 }
748                 // Given the data is buffered, it's okay to ignore the number of written bytes.
749                 //
750                 // TODO: In the long term, `stream.write` should be made sync. Then we could
751                 // remove the whole `if` condition and just call `stream.write`.
752                 Poll::Ready(Ok(_)) => {
753                     let stream = self.stream.clone();
754                     let bytes = Bytes::copy_from_slice(buf);
755                     self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await }));
756                     Poll::Ready(Ok(buf.len()))
757                 }
758             }
759         } else {
760             let stream = self.stream.clone();
761             let bytes = Bytes::copy_from_slice(buf);
762             let fut = self
763                 .write_fut
764                 .insert(Box::pin(async move { stream.write(&bytes).await }));
765 
766             match fut.as_mut().poll(cx) {
767                 // If it's the first time we're polling the future, `Poll::Pending` can't be
768                 // returned because that would mean the `PollStream` is not ready for writing. And
769                 // this is not true since we've just created a future, which is going to write the
770                 // buf to the underlying stream.
771                 //
772                 // It's okay to return `Poll::Ready` if the data is buffered (this is what the
773                 // buffered writer and `File` do).
774                 Poll::Pending => Poll::Ready(Ok(buf.len())),
775                 Poll::Ready(Err(e)) => {
776                     self.write_fut = None;
777                     Poll::Ready(Err(e.into()))
778                 }
779                 Poll::Ready(Ok(n)) => {
780                     self.write_fut = None;
781                     Poll::Ready(Ok(n))
782                 }
783             }
784         }
785     }
786 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>787     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
788         match self.write_fut.as_mut() {
789             Some(fut) => match fut.as_mut().poll(cx) {
790                 Poll::Pending => Poll::Pending,
791                 Poll::Ready(Err(e)) => {
792                     self.write_fut = None;
793                     Poll::Ready(Err(e.into()))
794                 }
795                 Poll::Ready(Ok(_)) => {
796                     self.write_fut = None;
797                     Poll::Ready(Ok(()))
798                 }
799             },
800             None => Poll::Ready(Ok(())),
801         }
802     }
803 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>804     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
805         match self.as_mut().poll_flush(cx) {
806             Poll::Pending => return Poll::Pending,
807             Poll::Ready(_) => {}
808         }
809         let fut = match self.shutdown_fut {
810             ShutdownFut::Done => return Poll::Ready(Ok(())),
811             ShutdownFut::Errored(ref err) => return Poll::Ready(Err(err.clone().into())),
812             ShutdownFut::ShuttingDown(ref mut fut) => fut,
813             ShutdownFut::Idle => {
814                 let stream = self.stream.clone();
815                 self.shutdown_fut = ShutdownFut::ShuttingDown(Box::pin(async move {
816                     stream.shutdown(Shutdown::Write).await
817                 }));
818                 self.shutdown_fut.get_shutting_down_mut()
819             }
820         };
821 
822         match fut.as_mut().poll(cx) {
823             Poll::Pending => Poll::Pending,
824             Poll::Ready(Err(e)) => {
825                 self.shutdown_fut = ShutdownFut::Errored(e.clone());
826                 Poll::Ready(Err(e.into()))
827             }
828             Poll::Ready(Ok(_)) => {
829                 self.shutdown_fut = ShutdownFut::Done;
830                 Poll::Ready(Ok(()))
831             }
832         }
833     }
834 }
835 
836 impl Clone for PollStream {
clone(&self) -> PollStream837     fn clone(&self) -> PollStream {
838         PollStream::new(self.clone_inner())
839     }
840 }
841 
842 impl fmt::Debug for PollStream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result843     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
844         f.debug_struct("PollStream")
845             .field("stream", &self.stream)
846             .field("read_buf_cap", &self.read_buf_cap)
847             .finish()
848     }
849 }
850 
851 impl AsRef<Stream> for PollStream {
as_ref(&self) -> &Stream852     fn as_ref(&self) -> &Stream {
853         &self.stream
854     }
855 }
856