1 #[cfg(test)] 2 mod stream_test; 3 4 use crate::association::AssociationState; 5 use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier}; 6 use crate::error::{Error, Result}; 7 use crate::queue::pending_queue::PendingQueue; 8 use crate::queue::reassembly_queue::ReassemblyQueue; 9 10 use arc_swap::ArcSwapOption; 11 use bytes::Bytes; 12 use std::{ 13 fmt, 14 future::Future, 15 io, 16 net::Shutdown, 17 pin::Pin, 18 sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering}, 19 sync::Arc, 20 task::{Context, Poll}, 21 }; 22 use tokio::{ 23 io::{AsyncRead, AsyncWrite, ReadBuf}, 24 sync::{mpsc, Mutex, Notify}, 25 }; 26 27 #[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] 28 #[repr(C)] 29 pub enum ReliabilityType { 30 /// ReliabilityTypeReliable is used for reliable transmission 31 #[default] 32 Reliable = 0, 33 /// ReliabilityTypeRexmit is used for partial reliability by retransmission count 34 Rexmit = 1, 35 /// ReliabilityTypeTimed is used for partial reliability by retransmission duration 36 Timed = 2, 37 } 38 39 impl fmt::Display for ReliabilityType { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 41 let s = match *self { 42 ReliabilityType::Reliable => "Reliable", 43 ReliabilityType::Rexmit => "Rexmit", 44 ReliabilityType::Timed => "Timed", 45 }; 46 write!(f, "{s}") 47 } 48 } 49 50 impl From<u8> for ReliabilityType { from(v: u8) -> ReliabilityType51 fn from(v: u8) -> ReliabilityType { 52 match v { 53 1 => ReliabilityType::Rexmit, 54 2 => ReliabilityType::Timed, 55 _ => ReliabilityType::Reliable, 56 } 57 } 58 } 59 60 pub type OnBufferedAmountLowFn = 61 Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync>; 62 63 // TODO: benchmark performance between multiple Atomic+Mutex vs one Mutex<StreamInternal> 64 65 /// Stream represents an SCTP stream 66 #[derive(Default)] 67 pub struct Stream { 68 pub(crate) max_payload_size: u32, 69 pub(crate) max_message_size: Arc<AtomicU32>, // clone from association 70 pub(crate) state: Arc<AtomicU8>, // clone from association 71 pub(crate) awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>, 72 pub(crate) pending_queue: Arc<PendingQueue>, 73 74 pub(crate) stream_identifier: u16, 75 pub(crate) default_payload_type: AtomicU32, //PayloadProtocolIdentifier, 76 pub(crate) reassembly_queue: Mutex<ReassemblyQueue>, 77 pub(crate) sequence_number: AtomicU16, 78 pub(crate) read_notifier: Notify, 79 pub(crate) read_shutdown: AtomicBool, 80 pub(crate) write_shutdown: AtomicBool, 81 pub(crate) unordered: AtomicBool, 82 pub(crate) reliability_type: AtomicU8, //ReliabilityType, 83 pub(crate) reliability_value: AtomicU32, 84 pub(crate) buffered_amount: AtomicUsize, 85 pub(crate) buffered_amount_low: AtomicUsize, 86 pub(crate) on_buffered_amount_low: ArcSwapOption<Mutex<OnBufferedAmountLowFn>>, 87 pub(crate) name: String, 88 } 89 90 impl fmt::Debug for Stream { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 92 f.debug_struct("Stream") 93 .field("max_payload_size", &self.max_payload_size) 94 .field("max_message_size", &self.max_message_size) 95 .field("state", &self.state) 96 .field("awake_write_loop_ch", &self.awake_write_loop_ch) 97 .field("stream_identifier", &self.stream_identifier) 98 .field("default_payload_type", &self.default_payload_type) 99 .field("reassembly_queue", &self.reassembly_queue) 100 .field("sequence_number", &self.sequence_number) 101 .field("read_shutdown", &self.read_shutdown) 102 .field("write_shutdown", &self.write_shutdown) 103 .field("unordered", &self.unordered) 104 .field("reliability_type", &self.reliability_type) 105 .field("reliability_value", &self.reliability_value) 106 .field("buffered_amount", &self.buffered_amount) 107 .field("buffered_amount_low", &self.buffered_amount_low) 108 .field("name", &self.name) 109 .finish() 110 } 111 } 112 113 impl Stream { new( name: String, stream_identifier: u16, max_payload_size: u32, max_message_size: Arc<AtomicU32>, state: Arc<AtomicU8>, awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>, pending_queue: Arc<PendingQueue>, ) -> Self114 pub(crate) fn new( 115 name: String, 116 stream_identifier: u16, 117 max_payload_size: u32, 118 max_message_size: Arc<AtomicU32>, 119 state: Arc<AtomicU8>, 120 awake_write_loop_ch: Option<Arc<mpsc::Sender<()>>>, 121 pending_queue: Arc<PendingQueue>, 122 ) -> Self { 123 Stream { 124 max_payload_size, 125 max_message_size, 126 state, 127 awake_write_loop_ch, 128 pending_queue, 129 130 stream_identifier, 131 default_payload_type: AtomicU32::new(0), //PayloadProtocolIdentifier::Unknown, 132 reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)), 133 sequence_number: AtomicU16::new(0), 134 read_notifier: Notify::new(), 135 read_shutdown: AtomicBool::new(false), 136 write_shutdown: AtomicBool::new(false), 137 unordered: AtomicBool::new(false), 138 reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable, 139 reliability_value: AtomicU32::new(0), 140 buffered_amount: AtomicUsize::new(0), 141 buffered_amount_low: AtomicUsize::new(0), 142 on_buffered_amount_low: ArcSwapOption::empty(), 143 name, 144 } 145 } 146 147 /// stream_identifier returns the Stream identifier associated to the stream. stream_identifier(&self) -> u16148 pub fn stream_identifier(&self) -> u16 { 149 self.stream_identifier 150 } 151 152 /// set_default_payload_type sets the default payload type used by write. set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier)153 pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) { 154 self.default_payload_type 155 .store(default_payload_type as u32, Ordering::SeqCst); 156 } 157 158 /// set_reliability_params sets reliability parameters for this stream. set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32)159 pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) { 160 log::debug!( 161 "[{}] reliability params: ordered={} type={} value={}", 162 self.name, 163 !unordered, 164 rel_type, 165 rel_val 166 ); 167 self.unordered.store(unordered, Ordering::SeqCst); 168 self.reliability_type 169 .store(rel_type as u8, Ordering::SeqCst); 170 self.reliability_value.store(rel_val, Ordering::SeqCst); 171 } 172 173 /// Reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. 174 /// 175 /// Returns `Error::ErrShortBuffer` if `p` is too short. 176 /// Returns `0` if the reading half of this stream is shutdown or it (the stream) was reset. read(&self, p: &mut [u8]) -> Result<usize>177 pub async fn read(&self, p: &mut [u8]) -> Result<usize> { 178 let (n, _) = self.read_sctp(p).await?; 179 Ok(n) 180 } 181 182 /// Reads a packet of len(p) bytes and returns the associated Payload Protocol Identifier. 183 /// 184 /// Returns `Error::ErrShortBuffer` if `p` is too short. 185 /// Returns `(0, PayloadProtocolIdentifier::Unknown)` if the reading half of this stream is shutdown or it (the stream) was reset. read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)>186 pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> { 187 loop { 188 if self.read_shutdown.load(Ordering::SeqCst) { 189 return Ok((0, PayloadProtocolIdentifier::Unknown)); 190 } 191 192 let result = { 193 let mut reassembly_queue = self.reassembly_queue.lock().await; 194 reassembly_queue.read(p) 195 }; 196 197 match result { 198 Ok(_) | Err(Error::ErrShortBuffer) => return result, 199 Err(_) => { 200 // wait for the next chunk to become available 201 self.read_notifier.notified().await; 202 } 203 } 204 } 205 } 206 handle_data(&self, pd: ChunkPayloadData)207 pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) { 208 let readable = { 209 let mut reassembly_queue = self.reassembly_queue.lock().await; 210 if reassembly_queue.push(pd) { 211 let readable = reassembly_queue.is_readable(); 212 log::debug!("[{}] reassemblyQueue readable={}", self.name, readable); 213 readable 214 } else { 215 false 216 } 217 }; 218 219 if readable { 220 log::debug!("[{}] readNotifier.signal()", self.name); 221 self.read_notifier.notify_one(); 222 log::debug!("[{}] readNotifier.signal() done", self.name); 223 } 224 } 225 handle_forward_tsn_for_ordered(&self, ssn: u16)226 pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) { 227 if self.unordered.load(Ordering::SeqCst) { 228 return; // unordered chunks are handled by handleForwardUnordered method 229 } 230 231 // Remove all chunks older than or equal to the new TSN from 232 // the reassembly_queue. 233 let readable = { 234 let mut reassembly_queue = self.reassembly_queue.lock().await; 235 reassembly_queue.forward_tsn_for_ordered(ssn); 236 reassembly_queue.is_readable() 237 }; 238 239 // Notify the reader asynchronously if there's a data chunk to read. 240 if readable { 241 self.read_notifier.notify_one(); 242 } 243 } 244 handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32)245 pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) { 246 if !self.unordered.load(Ordering::SeqCst) { 247 return; // ordered chunks are handled by handleForwardTSNOrdered method 248 } 249 250 // Remove all chunks older than or equal to the new TSN from 251 // the reassembly_queue. 252 let readable = { 253 let mut reassembly_queue = self.reassembly_queue.lock().await; 254 reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn); 255 reassembly_queue.is_readable() 256 }; 257 258 // Notify the reader asynchronously if there's a data chunk to read. 259 if readable { 260 self.read_notifier.notify_one(); 261 } 262 } 263 264 /// Writes `p` to the DTLS connection with the default Payload Protocol Identifier. 265 /// 266 /// Returns an error if the write half of this stream is shutdown or `p` is too large. write(&self, p: &Bytes) -> Result<usize>267 pub async fn write(&self, p: &Bytes) -> Result<usize> { 268 self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into()) 269 .await 270 } 271 272 /// Writes `p` to the DTLS connection with the given Payload Protocol Identifier. 273 /// 274 /// Returns an error if the write half of this stream is shutdown or `p` is too large. write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize>275 pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> { 276 let chunks = self.prepare_write(p, ppi)?; 277 self.send_payload_data(chunks).await?; 278 279 Ok(p.len()) 280 } 281 282 /// common stuff for write and try_write prepare_write( &self, p: &Bytes, ppi: PayloadProtocolIdentifier, ) -> Result<Vec<ChunkPayloadData>>283 fn prepare_write( 284 &self, 285 p: &Bytes, 286 ppi: PayloadProtocolIdentifier, 287 ) -> Result<Vec<ChunkPayloadData>> { 288 if self.write_shutdown.load(Ordering::SeqCst) { 289 return Err(Error::ErrStreamClosed); 290 } 291 292 if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize { 293 return Err(Error::ErrOutboundPacketTooLarge); 294 } 295 296 let state: AssociationState = self.state.load(Ordering::SeqCst).into(); 297 match state { 298 AssociationState::ShutdownSent 299 | AssociationState::ShutdownAckSent 300 | AssociationState::ShutdownPending 301 | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed), 302 _ => {} 303 }; 304 305 Ok(self.packetize(p, ppi)) 306 } 307 packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData>308 fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> { 309 let mut i = 0; 310 let mut remaining = raw.len(); 311 312 // From draft-ietf-rtcweb-data-protocol-09, section 6: 313 // All Data Channel Establishment Protocol messages MUST be sent using 314 // ordered delivery and reliable transmission. 315 let unordered = 316 ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst); 317 318 let mut chunks = vec![]; 319 320 let head_abandoned = Arc::new(AtomicBool::new(false)); 321 let head_all_inflight = Arc::new(AtomicBool::new(false)); 322 while remaining != 0 { 323 let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); //self.association.max_payload_size 324 325 // Copy the userdata since we'll have to store it until acked 326 // and the caller may re-use the buffer in the mean time 327 let user_data = raw.slice(i..i + fragment_size); 328 329 let chunk = ChunkPayloadData { 330 stream_identifier: self.stream_identifier, 331 user_data, 332 unordered, 333 beginning_fragment: i == 0, 334 ending_fragment: remaining - fragment_size == 0, 335 immediate_sack: false, 336 payload_type: ppi, 337 stream_sequence_number: self.sequence_number.load(Ordering::SeqCst), 338 abandoned: head_abandoned.clone(), // all fragmented chunks use the same abandoned 339 all_inflight: head_all_inflight.clone(), // all fragmented chunks use the same all_inflight 340 ..Default::default() 341 }; 342 343 chunks.push(chunk); 344 345 remaining -= fragment_size; 346 i += fragment_size; 347 } 348 349 // RFC 4960 Sec 6.6 350 // Note: When transmitting ordered and unordered data, an endpoint does 351 // not increment its Stream Sequence Number when transmitting a DATA 352 // chunk with U flag set to 1. 353 if !unordered { 354 self.sequence_number.fetch_add(1, Ordering::SeqCst); 355 } 356 357 let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst); 358 log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len()); 359 360 chunks 361 } 362 363 /// Closes both read and write halves of this stream. 364 /// 365 /// Use [`Stream::shutdown`] instead. 366 #[deprecated] close(&self) -> Result<()>367 pub async fn close(&self) -> Result<()> { 368 self.shutdown(Shutdown::Both).await 369 } 370 371 /// Shuts down the read, write, or both halves of this stream. 372 /// 373 /// This function will cause all pending and future I/O on the specified portions to return 374 /// immediately with an appropriate value (see the documentation of [`Shutdown`]). 375 /// 376 /// Resets the stream when both halves of this stream are shutdown. shutdown(&self, how: Shutdown) -> Result<()>377 pub async fn shutdown(&self, how: Shutdown) -> Result<()> { 378 if self.read_shutdown.load(Ordering::SeqCst) && self.write_shutdown.load(Ordering::SeqCst) { 379 return Ok(()); 380 } 381 382 if how == Shutdown::Write || how == Shutdown::Both { 383 self.write_shutdown.store(true, Ordering::SeqCst); 384 } 385 386 if (how == Shutdown::Read || how == Shutdown::Both) 387 && !self.read_shutdown.swap(true, Ordering::SeqCst) 388 { 389 self.read_notifier.notify_waiters(); 390 } 391 392 if how == Shutdown::Both 393 || (self.read_shutdown.load(Ordering::SeqCst) 394 && self.write_shutdown.load(Ordering::SeqCst)) 395 { 396 // Reset the stream 397 // https://tools.ietf.org/html/rfc6525 398 self.send_reset_request(self.stream_identifier).await?; 399 } 400 401 Ok(()) 402 } 403 404 /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. buffered_amount(&self) -> usize405 pub fn buffered_amount(&self) -> usize { 406 self.buffered_amount.load(Ordering::SeqCst) 407 } 408 409 /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is 410 /// considered "low." Defaults to 0. buffered_amount_low_threshold(&self) -> usize411 pub fn buffered_amount_low_threshold(&self) -> usize { 412 self.buffered_amount_low.load(Ordering::SeqCst) 413 } 414 415 /// set_buffered_amount_low_threshold is used to update the threshold. 416 /// See buffered_amount_low_threshold(). set_buffered_amount_low_threshold(&self, th: usize)417 pub fn set_buffered_amount_low_threshold(&self, th: usize) { 418 self.buffered_amount_low.store(th, Ordering::SeqCst); 419 } 420 421 /// on_buffered_amount_low sets the callback handler which would be called when the number of 422 /// bytes of outgoing data buffered is lower than the threshold. on_buffered_amount_low(&self, f: OnBufferedAmountLowFn)423 pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { 424 self.on_buffered_amount_low 425 .store(Some(Arc::new(Mutex::new(f)))); 426 } 427 428 /// This method is called by association's read_loop (go-)routine to notify this stream 429 /// of the specified amount of outgoing data has been delivered to the peer. on_buffer_released(&self, n_bytes_released: i64)430 pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) { 431 if n_bytes_released <= 0 { 432 return; 433 } 434 435 let from_amount = self.buffered_amount.load(Ordering::SeqCst); 436 let new_amount = if from_amount < n_bytes_released as usize { 437 self.buffered_amount.store(0, Ordering::SeqCst); 438 log::error!( 439 "[{}] released buffer size {} should be <= {}", 440 self.name, 441 n_bytes_released, 442 0, 443 ); 444 0 445 } else { 446 self.buffered_amount 447 .fetch_sub(n_bytes_released as usize, Ordering::SeqCst); 448 449 from_amount - n_bytes_released as usize 450 }; 451 452 let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst); 453 454 log::trace!( 455 "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}", 456 self.name, 457 new_amount, 458 from_amount, 459 buffered_amount_low, 460 ); 461 462 if from_amount > buffered_amount_low && new_amount <= buffered_amount_low { 463 if let Some(handler) = &*self.on_buffered_amount_low.load() { 464 let mut f = handler.lock().await; 465 f().await; 466 } 467 } 468 } 469 470 /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to 471 /// be read (once chunk is complete). get_num_bytes_in_reassembly_queue(&self) -> usize472 pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { 473 // No lock is required as it reads the size with atomic load function. 474 let reassembly_queue = self.reassembly_queue.lock().await; 475 reassembly_queue.get_num_bytes() 476 } 477 478 /// get_state atomically returns the state of the Association. get_state(&self) -> AssociationState479 fn get_state(&self) -> AssociationState { 480 self.state.load(Ordering::SeqCst).into() 481 } 482 awake_write_loop(&self)483 fn awake_write_loop(&self) { 484 //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name); 485 if let Some(awake_write_loop_ch) = &self.awake_write_loop_ch { 486 let _ = awake_write_loop_ch.try_send(()); 487 } 488 } 489 send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()>490 async fn send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()> { 491 let state = self.get_state(); 492 if state != AssociationState::Established { 493 return Err(Error::ErrPayloadDataStateNotExist); 494 } 495 496 // NOTE: append is used here instead of push in order to prevent chunks interlacing. 497 self.pending_queue.append(chunks).await; 498 499 self.awake_write_loop(); 500 Ok(()) 501 } 502 send_reset_request(&self, stream_identifier: u16) -> Result<()>503 async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> { 504 let state = self.get_state(); 505 if state != AssociationState::Established { 506 return Err(Error::ErrResetPacketInStateNotExist); 507 } 508 509 // Create DATA chunk which only contains valid stream identifier with 510 // nil userData and use it as a EOS from the stream. 511 let c = ChunkPayloadData { 512 stream_identifier, 513 beginning_fragment: true, 514 ending_fragment: true, 515 user_data: Bytes::new(), 516 ..Default::default() 517 }; 518 519 self.pending_queue.push(c).await; 520 521 self.awake_write_loop(); 522 Ok(()) 523 } 524 } 525 526 /// Default capacity of the temporary read buffer used by [`PollStream`]. 527 const DEFAULT_READ_BUF_SIZE: usize = 8192; 528 529 /// State of the read `Future` in [`PollStream`]. 530 enum ReadFut { 531 /// Nothing in progress. 532 Idle, 533 /// Reading data from the underlying stream. 534 Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>), 535 /// Finished reading, but there's unread data in the temporary buffer. 536 RemainingData(Vec<u8>), 537 } 538 539 enum ShutdownFut { 540 /// Nothing in progress. 541 Idle, 542 /// Reading data from the underlying stream. 543 ShuttingDown(Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>>), 544 /// Shutdown future has run 545 Done, 546 Errored(crate::error::Error), 547 } 548 549 impl ReadFut { 550 /// Gets a mutable reference to the future stored inside `Reading(future)`. 551 /// 552 /// # Panics 553 /// 554 /// Panics if `ReadFut` variant is not `Reading`. get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>555 fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> { 556 match self { 557 ReadFut::Reading(ref mut fut) => fut, 558 _ => panic!("expected ReadFut to be Reading"), 559 } 560 } 561 } 562 563 impl ShutdownFut { 564 /// Gets a mutable reference to the future stored inside `ShuttingDown(future)`. 565 /// 566 /// # Panics 567 /// 568 /// Panics if `ShutdownFut` variant is not `ShuttingDown`. get_shutting_down_mut( &mut self, ) -> &mut Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>>569 fn get_shutting_down_mut( 570 &mut self, 571 ) -> &mut Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>> { 572 match self { 573 ShutdownFut::ShuttingDown(ref mut fut) => fut, 574 _ => panic!("expected ShutdownFut to be ShuttingDown"), 575 } 576 } 577 } 578 579 /// A wrapper around around [`Stream`], which implements [`AsyncRead`] and 580 /// [`AsyncWrite`]. 581 /// 582 /// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an 583 /// additional overhead. 584 pub struct PollStream { 585 stream: Arc<Stream>, 586 587 read_fut: ReadFut, 588 write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>>>>>, 589 shutdown_fut: ShutdownFut, 590 591 read_buf_cap: usize, 592 } 593 594 impl PollStream { 595 /// Constructs a new `PollStream`. 596 /// 597 /// # Examples 598 /// 599 /// ``` 600 /// use webrtc_sctp::stream::{Stream, PollStream}; 601 /// use std::sync::Arc; 602 /// 603 /// let stream = Arc::new(Stream::default()); 604 /// let poll_stream = PollStream::new(stream); 605 /// ``` new(stream: Arc<Stream>) -> Self606 pub fn new(stream: Arc<Stream>) -> Self { 607 Self { 608 stream, 609 read_fut: ReadFut::Idle, 610 write_fut: None, 611 shutdown_fut: ShutdownFut::Idle, 612 read_buf_cap: DEFAULT_READ_BUF_SIZE, 613 } 614 } 615 616 /// Get back the inner stream. 617 #[must_use] into_inner(self) -> Arc<Stream>618 pub fn into_inner(self) -> Arc<Stream> { 619 self.stream 620 } 621 622 /// Obtain a clone of the inner stream. 623 #[must_use] clone_inner(&self) -> Arc<Stream>624 pub fn clone_inner(&self) -> Arc<Stream> { 625 self.stream.clone() 626 } 627 628 /// stream_identifier returns the Stream identifier associated to the stream. stream_identifier(&self) -> u16629 pub fn stream_identifier(&self) -> u16 { 630 self.stream.stream_identifier 631 } 632 633 /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream. buffered_amount(&self) -> usize634 pub fn buffered_amount(&self) -> usize { 635 self.stream.buffered_amount.load(Ordering::SeqCst) 636 } 637 638 /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is 639 /// considered "low." Defaults to 0. buffered_amount_low_threshold(&self) -> usize640 pub fn buffered_amount_low_threshold(&self) -> usize { 641 self.stream.buffered_amount_low.load(Ordering::SeqCst) 642 } 643 644 /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to 645 /// be read (once chunk is complete). get_num_bytes_in_reassembly_queue(&self) -> usize646 pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize { 647 // No lock is required as it reads the size with atomic load function. 648 let reassembly_queue = self.stream.reassembly_queue.lock().await; 649 reassembly_queue.get_num_bytes() 650 } 651 652 /// Set the capacity of the temporary read buffer (default: 8192). set_read_buf_capacity(&mut self, capacity: usize)653 pub fn set_read_buf_capacity(&mut self, capacity: usize) { 654 self.read_buf_cap = capacity 655 } 656 } 657 658 impl AsyncRead for PollStream { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>659 fn poll_read( 660 mut self: Pin<&mut Self>, 661 cx: &mut Context<'_>, 662 buf: &mut ReadBuf<'_>, 663 ) -> Poll<io::Result<()>> { 664 if buf.remaining() == 0 { 665 return Poll::Ready(Ok(())); 666 } 667 668 let fut = match self.read_fut { 669 ReadFut::Idle => { 670 // read into a temporary buffer because `buf` has an unonymous lifetime, which can 671 // be shorter than the lifetime of `read_fut`. 672 let stream = self.stream.clone(); 673 let mut temp_buf = vec![0; self.read_buf_cap]; 674 self.read_fut = ReadFut::Reading(Box::pin(async move { 675 stream.read(temp_buf.as_mut_slice()).await.map(|n| { 676 temp_buf.truncate(n); 677 temp_buf 678 }) 679 })); 680 self.read_fut.get_reading_mut() 681 } 682 ReadFut::Reading(ref mut fut) => fut, 683 ReadFut::RemainingData(ref mut data) => { 684 let remaining = buf.remaining(); 685 let len = std::cmp::min(data.len(), remaining); 686 buf.put_slice(&data[..len]); 687 if data.len() > remaining { 688 // ReadFut remains to be RemainingData 689 data.drain(0..len); 690 } else { 691 self.read_fut = ReadFut::Idle; 692 } 693 return Poll::Ready(Ok(())); 694 } 695 }; 696 697 loop { 698 match fut.as_mut().poll(cx) { 699 Poll::Pending => return Poll::Pending, 700 // retry immediately upon empty data or incomplete chunks 701 // since there's no way to setup a waker. 702 Poll::Ready(Err(Error::ErrTryAgain)) => {} 703 // EOF has been reached => don't touch buf and just return Ok 704 Poll::Ready(Err(Error::ErrEof)) => { 705 self.read_fut = ReadFut::Idle; 706 return Poll::Ready(Ok(())); 707 } 708 Poll::Ready(Err(e)) => { 709 self.read_fut = ReadFut::Idle; 710 return Poll::Ready(Err(e.into())); 711 } 712 Poll::Ready(Ok(mut temp_buf)) => { 713 let remaining = buf.remaining(); 714 let len = std::cmp::min(temp_buf.len(), remaining); 715 buf.put_slice(&temp_buf[..len]); 716 if temp_buf.len() > remaining { 717 temp_buf.drain(0..len); 718 self.read_fut = ReadFut::RemainingData(temp_buf); 719 } else { 720 self.read_fut = ReadFut::Idle; 721 } 722 return Poll::Ready(Ok(())); 723 } 724 } 725 } 726 } 727 } 728 729 impl AsyncWrite for PollStream { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>730 fn poll_write( 731 mut self: Pin<&mut Self>, 732 cx: &mut Context<'_>, 733 buf: &[u8], 734 ) -> Poll<io::Result<usize>> { 735 if buf.is_empty() { 736 return Poll::Ready(Ok(0)); 737 } 738 739 if let Some(fut) = self.write_fut.as_mut() { 740 match fut.as_mut().poll(cx) { 741 Poll::Pending => Poll::Pending, 742 Poll::Ready(Err(e)) => { 743 let stream = self.stream.clone(); 744 let bytes = Bytes::copy_from_slice(buf); 745 self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await })); 746 Poll::Ready(Err(e.into())) 747 } 748 // Given the data is buffered, it's okay to ignore the number of written bytes. 749 // 750 // TODO: In the long term, `stream.write` should be made sync. Then we could 751 // remove the whole `if` condition and just call `stream.write`. 752 Poll::Ready(Ok(_)) => { 753 let stream = self.stream.clone(); 754 let bytes = Bytes::copy_from_slice(buf); 755 self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await })); 756 Poll::Ready(Ok(buf.len())) 757 } 758 } 759 } else { 760 let stream = self.stream.clone(); 761 let bytes = Bytes::copy_from_slice(buf); 762 let fut = self 763 .write_fut 764 .insert(Box::pin(async move { stream.write(&bytes).await })); 765 766 match fut.as_mut().poll(cx) { 767 // If it's the first time we're polling the future, `Poll::Pending` can't be 768 // returned because that would mean the `PollStream` is not ready for writing. And 769 // this is not true since we've just created a future, which is going to write the 770 // buf to the underlying stream. 771 // 772 // It's okay to return `Poll::Ready` if the data is buffered (this is what the 773 // buffered writer and `File` do). 774 Poll::Pending => Poll::Ready(Ok(buf.len())), 775 Poll::Ready(Err(e)) => { 776 self.write_fut = None; 777 Poll::Ready(Err(e.into())) 778 } 779 Poll::Ready(Ok(n)) => { 780 self.write_fut = None; 781 Poll::Ready(Ok(n)) 782 } 783 } 784 } 785 } 786 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>787 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 788 match self.write_fut.as_mut() { 789 Some(fut) => match fut.as_mut().poll(cx) { 790 Poll::Pending => Poll::Pending, 791 Poll::Ready(Err(e)) => { 792 self.write_fut = None; 793 Poll::Ready(Err(e.into())) 794 } 795 Poll::Ready(Ok(_)) => { 796 self.write_fut = None; 797 Poll::Ready(Ok(())) 798 } 799 }, 800 None => Poll::Ready(Ok(())), 801 } 802 } 803 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>804 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 805 match self.as_mut().poll_flush(cx) { 806 Poll::Pending => return Poll::Pending, 807 Poll::Ready(_) => {} 808 } 809 let fut = match self.shutdown_fut { 810 ShutdownFut::Done => return Poll::Ready(Ok(())), 811 ShutdownFut::Errored(ref err) => return Poll::Ready(Err(err.clone().into())), 812 ShutdownFut::ShuttingDown(ref mut fut) => fut, 813 ShutdownFut::Idle => { 814 let stream = self.stream.clone(); 815 self.shutdown_fut = ShutdownFut::ShuttingDown(Box::pin(async move { 816 stream.shutdown(Shutdown::Write).await 817 })); 818 self.shutdown_fut.get_shutting_down_mut() 819 } 820 }; 821 822 match fut.as_mut().poll(cx) { 823 Poll::Pending => Poll::Pending, 824 Poll::Ready(Err(e)) => { 825 self.shutdown_fut = ShutdownFut::Errored(e.clone()); 826 Poll::Ready(Err(e.into())) 827 } 828 Poll::Ready(Ok(_)) => { 829 self.shutdown_fut = ShutdownFut::Done; 830 Poll::Ready(Ok(())) 831 } 832 } 833 } 834 } 835 836 impl Clone for PollStream { clone(&self) -> PollStream837 fn clone(&self) -> PollStream { 838 PollStream::new(self.clone_inner()) 839 } 840 } 841 842 impl fmt::Debug for PollStream { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result843 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 844 f.debug_struct("PollStream") 845 .field("stream", &self.stream) 846 .field("read_buf_cap", &self.read_buf_cap) 847 .finish() 848 } 849 } 850 851 impl AsRef<Stream> for PollStream { as_ref(&self) -> &Stream852 fn as_ref(&self) -> &Stream { 853 &self.stream 854 } 855 } 856