xref: /webrtc/data/src/data_channel/mod.rs (revision 84b8594c)
1 #[cfg(test)]
2 mod data_channel_test;
3 
4 use crate::error::Result;
5 use crate::{
6     error::Error, message::message_channel_ack::*, message::message_channel_open::*, message::*,
7 };
8 
9 use sctp::{
10     association::Association, chunk::chunk_payload_data::PayloadProtocolIdentifier, stream::*,
11 };
12 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 use util::marshal::*;
14 
15 use bytes::{Buf, Bytes};
16 use std::borrow::Borrow;
17 use std::fmt;
18 use std::future::Future;
19 use std::io;
20 use std::net::Shutdown;
21 use std::pin::Pin;
22 use std::sync::atomic::{AtomicUsize, Ordering};
23 use std::sync::Arc;
24 use std::task::{Context, Poll};
25 
/// Size in bytes of the buffer that `DataChannel::server` uses to receive the
/// initial DATA_CHANNEL_OPEN message.
const RECEIVE_MTU: usize = 8 * 1024;
27 
28 /// Config is used to configure the data channel.
29 #[derive(Eq, PartialEq, Default, Clone, Debug)]
30 pub struct Config {
31     pub channel_type: ChannelType,
32     pub negotiated: bool,
33     pub priority: u16,
34     pub reliability_parameter: u32,
35     pub label: String,
36     pub protocol: String,
37 }
38 
39 /// DataChannel represents a data channel
40 #[derive(Debug, Default, Clone)]
41 pub struct DataChannel {
42     pub config: Config,
43     stream: Arc<Stream>,
44 
45     // stats
46     messages_sent: Arc<AtomicUsize>,
47     messages_received: Arc<AtomicUsize>,
48     bytes_sent: Arc<AtomicUsize>,
49     bytes_received: Arc<AtomicUsize>,
50 }
51 
52 impl DataChannel {
new(stream: Arc<Stream>, config: Config) -> Self53     pub fn new(stream: Arc<Stream>, config: Config) -> Self {
54         Self {
55             config,
56             stream,
57             ..Default::default()
58         }
59     }
60 
61     /// Dial opens a data channels over SCTP
dial( association: &Arc<Association>, identifier: u16, config: Config, ) -> Result<Self>62     pub async fn dial(
63         association: &Arc<Association>,
64         identifier: u16,
65         config: Config,
66     ) -> Result<Self> {
67         let stream = association
68             .open_stream(identifier, PayloadProtocolIdentifier::Binary)
69             .await?;
70 
71         Self::client(stream, config).await
72     }
73 
74     /// Accept is used to accept incoming data channels over SCTP
accept<T>( association: &Arc<Association>, config: Config, existing_channels: &[T], ) -> Result<Self> where T: Borrow<Self>,75     pub async fn accept<T>(
76         association: &Arc<Association>,
77         config: Config,
78         existing_channels: &[T],
79     ) -> Result<Self>
80     where
81         T: Borrow<Self>,
82     {
83         let stream = association
84             .accept_stream()
85             .await
86             .ok_or(Error::ErrStreamClosed)?;
87 
88         for channel in existing_channels.iter().map(|ch| ch.borrow()) {
89             if channel.stream_identifier() == stream.stream_identifier() {
90                 let ch = channel.to_owned();
91                 ch.stream
92                     .set_default_payload_type(PayloadProtocolIdentifier::Binary);
93                 return Ok(ch);
94             }
95         }
96 
97         stream.set_default_payload_type(PayloadProtocolIdentifier::Binary);
98 
99         Self::server(stream, config).await
100     }
101 
102     /// Client opens a data channel over an SCTP stream
client(stream: Arc<Stream>, config: Config) -> Result<Self>103     pub async fn client(stream: Arc<Stream>, config: Config) -> Result<Self> {
104         if !config.negotiated {
105             let msg = Message::DataChannelOpen(DataChannelOpen {
106                 channel_type: config.channel_type,
107                 priority: config.priority,
108                 reliability_parameter: config.reliability_parameter,
109                 label: config.label.bytes().collect(),
110                 protocol: config.protocol.bytes().collect(),
111             })
112             .marshal()?;
113 
114             stream
115                 .write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
116                 .await?;
117         }
118         Ok(DataChannel::new(stream, config))
119     }
120 
121     /// Server accepts a data channel over an SCTP stream
server(stream: Arc<Stream>, mut config: Config) -> Result<Self>122     pub async fn server(stream: Arc<Stream>, mut config: Config) -> Result<Self> {
123         let mut buf = vec![0u8; RECEIVE_MTU];
124 
125         let (n, ppi) = stream.read_sctp(&mut buf).await?;
126 
127         if ppi != PayloadProtocolIdentifier::Dcep {
128             return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
129         }
130 
131         let mut read_buf = &buf[..n];
132         let msg = Message::unmarshal(&mut read_buf)?;
133 
134         if let Message::DataChannelOpen(dco) = msg {
135             config.channel_type = dco.channel_type;
136             config.priority = dco.priority;
137             config.reliability_parameter = dco.reliability_parameter;
138             config.label = String::from_utf8(dco.label)?;
139             config.protocol = String::from_utf8(dco.protocol)?;
140         } else {
141             return Err(Error::InvalidMessageType(msg.message_type() as u8));
142         };
143 
144         let data_channel = DataChannel::new(stream, config);
145 
146         data_channel.write_data_channel_ack().await?;
147         data_channel.commit_reliability_params();
148 
149         Ok(data_channel)
150     }
151 
152     /// Read reads a packet of len(p) bytes as binary data.
153     ///
154     /// See [`sctp::stream::Stream::read_sctp`].
read(&self, buf: &mut [u8]) -> Result<usize>155     pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
156         self.read_data_channel(buf).await.map(|(n, _)| n)
157     }
158 
159     /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and
160     /// `true` if the data read is a string.
161     ///
162     /// See [`sctp::stream::Stream::read_sctp`].
read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)>163     pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> {
164         loop {
165             //TODO: add handling of cancel read_data_channel
166             let (mut n, ppi) = match self.stream.read_sctp(buf).await {
167                 Ok((0, PayloadProtocolIdentifier::Unknown)) => {
168                     // The incoming stream was reset or the reading half was shutdown
169                     return Ok((0, false));
170                 }
171                 Ok((n, ppi)) => (n, ppi),
172                 Err(err) => {
173                     // Shutdown the stream and send the reset request to the remote.
174                     self.close().await?;
175                     return Err(err.into());
176                 }
177             };
178 
179             let mut is_string = false;
180             match ppi {
181                 PayloadProtocolIdentifier::Dcep => {
182                     let mut data = &buf[..n];
183                     match self.handle_dcep(&mut data).await {
184                         Ok(()) => {}
185                         Err(err) => {
186                             log::error!("Failed to handle DCEP: {:?}", err);
187                         }
188                     }
189                     continue;
190                 }
191                 PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => {
192                     is_string = true;
193                 }
194                 _ => {}
195             };
196 
197             match ppi {
198                 PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => {
199                     n = 0;
200                 }
201                 _ => {}
202             };
203 
204             self.messages_received.fetch_add(1, Ordering::SeqCst);
205             self.bytes_received.fetch_add(n, Ordering::SeqCst);
206 
207             return Ok((n, is_string));
208         }
209     }
210 
211     /// MessagesSent returns the number of messages sent
messages_sent(&self) -> usize212     pub fn messages_sent(&self) -> usize {
213         self.messages_sent.load(Ordering::SeqCst)
214     }
215 
216     /// MessagesReceived returns the number of messages received
messages_received(&self) -> usize217     pub fn messages_received(&self) -> usize {
218         self.messages_received.load(Ordering::SeqCst)
219     }
220 
221     /// BytesSent returns the number of bytes sent
bytes_sent(&self) -> usize222     pub fn bytes_sent(&self) -> usize {
223         self.bytes_sent.load(Ordering::SeqCst)
224     }
225 
226     /// BytesReceived returns the number of bytes received
bytes_received(&self) -> usize227     pub fn bytes_received(&self) -> usize {
228         self.bytes_received.load(Ordering::SeqCst)
229     }
230 
231     /// StreamIdentifier returns the Stream identifier associated to the stream.
stream_identifier(&self) -> u16232     pub fn stream_identifier(&self) -> u16 {
233         self.stream.stream_identifier()
234     }
235 
handle_dcep<B>(&self, data: &mut B) -> Result<()> where B: Buf,236     async fn handle_dcep<B>(&self, data: &mut B) -> Result<()>
237     where
238         B: Buf,
239     {
240         let msg = Message::unmarshal(data)?;
241 
242         match msg {
243             Message::DataChannelOpen(_) => {
244                 // Note: DATA_CHANNEL_OPEN message is handled inside Server() method.
245                 // Therefore, the message will not reach here.
246                 log::debug!("Received DATA_CHANNEL_OPEN");
247                 let _ = self.write_data_channel_ack().await?;
248             }
249             Message::DataChannelAck(_) => {
250                 log::debug!("Received DATA_CHANNEL_ACK");
251                 self.commit_reliability_params();
252             }
253         };
254 
255         Ok(())
256     }
257 
258     /// Write writes len(p) bytes from p as binary data
write(&self, data: &Bytes) -> Result<usize>259     pub async fn write(&self, data: &Bytes) -> Result<usize> {
260         self.write_data_channel(data, false).await
261     }
262 
263     /// WriteDataChannel writes len(p) bytes from p
write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize>264     pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize> {
265         let data_len = data.len();
266 
267         // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
268         // SCTP does not support the sending of empty user messages.  Therefore,
269         // if an empty message has to be sent, the appropriate PPID (WebRTC
270         // String Empty or WebRTC Binary Empty) is used and the SCTP user
271         // message of one zero byte is sent.  When receiving an SCTP user
272         // message with one of these PPIDs, the receiver MUST ignore the SCTP
273         // user message and process it as an empty message.
274         let ppi = match (is_string, data_len) {
275             (false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
276             (false, _) => PayloadProtocolIdentifier::Binary,
277             (true, 0) => PayloadProtocolIdentifier::StringEmpty,
278             (true, _) => PayloadProtocolIdentifier::String,
279         };
280 
281         let n = if data_len == 0 {
282             let _ = self
283                 .stream
284                 .write_sctp(&Bytes::from_static(&[0]), ppi)
285                 .await?;
286             0
287         } else {
288             let n = self.stream.write_sctp(data, ppi).await?;
289             self.bytes_sent.fetch_add(n, Ordering::SeqCst);
290             n
291         };
292 
293         self.messages_sent.fetch_add(1, Ordering::SeqCst);
294         Ok(n)
295     }
296 
write_data_channel_ack(&self) -> Result<usize>297     async fn write_data_channel_ack(&self) -> Result<usize> {
298         let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
299         Ok(self
300             .stream
301             .write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
302             .await?)
303     }
304 
305     /// Close closes the DataChannel and the underlying SCTP stream.
close(&self) -> Result<()>306     pub async fn close(&self) -> Result<()> {
307         // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
308         // Closing of a data channel MUST be signaled by resetting the
309         // corresponding outgoing streams [RFC6525].  This means that if one
310         // side decides to close the data channel, it resets the corresponding
311         // outgoing stream.  When the peer sees that an incoming stream was
312         // reset, it also resets its corresponding outgoing stream.  Once this
313         // is completed, the data channel is closed.  Resetting a stream sets
314         // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
315         // a corresponding notification to the application layer that the reset
316         // has been performed.  Streams are available for reuse after a reset
317         // has been performed.
318         Ok(self.stream.shutdown(Shutdown::Both).await?)
319     }
320 
321     /// BufferedAmount returns the number of bytes of data currently queued to be
322     /// sent over this stream.
buffered_amount(&self) -> usize323     pub fn buffered_amount(&self) -> usize {
324         self.stream.buffered_amount()
325     }
326 
327     /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
328     /// data that is considered "low." Defaults to 0.
buffered_amount_low_threshold(&self) -> usize329     pub fn buffered_amount_low_threshold(&self) -> usize {
330         self.stream.buffered_amount_low_threshold()
331     }
332 
333     /// SetBufferedAmountLowThreshold is used to update the threshold.
334     /// See BufferedAmountLowThreshold().
set_buffered_amount_low_threshold(&self, threshold: usize)335     pub fn set_buffered_amount_low_threshold(&self, threshold: usize) {
336         self.stream.set_buffered_amount_low_threshold(threshold)
337     }
338 
339     /// OnBufferedAmountLow sets the callback handler which would be called when the
340     /// number of bytes of outgoing data buffered is lower than the threshold.
on_buffered_amount_low(&self, f: OnBufferedAmountLowFn)341     pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
342         self.stream.on_buffered_amount_low(f)
343     }
344 
commit_reliability_params(&self)345     fn commit_reliability_params(&self) {
346         let (unordered, reliability_type) = match self.config.channel_type {
347             ChannelType::Reliable => (false, ReliabilityType::Reliable),
348             ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
349             ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
350             ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
351             ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
352             ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
353         };
354 
355         self.stream.set_reliability_params(
356             unordered,
357             reliability_type,
358             self.config.reliability_parameter,
359         );
360     }
361 }
362 
/// Initial capacity, in bytes, of the temporary buffer that each `PollDataChannel`
/// read allocates.
const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
365 
366 /// State of the read `Future` in [`PollStream`].
367 enum ReadFut {
368     /// Nothing in progress.
369     Idle,
370     /// Reading data from the underlying stream.
371     Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
372     /// Finished reading, but there's unread data in the temporary buffer.
373     RemainingData(Vec<u8>),
374 }
375 
376 impl ReadFut {
377     /// Gets a mutable reference to the future stored inside `Reading(future)`.
378     ///
379     /// # Panics
380     ///
381     /// Panics if `ReadFut` variant is not `Reading`.
get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>382     fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
383         match self {
384             ReadFut::Reading(ref mut fut) => fut,
385             _ => panic!("expected ReadFut to be Reading"),
386         }
387     }
388 }
389 
390 /// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
391 /// [`AsyncWrite`].
392 ///
393 /// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
394 /// additional overhead.
395 pub struct PollDataChannel {
396     data_channel: Arc<DataChannel>,
397 
398     read_fut: ReadFut,
399     write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
400     shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,
401 
402     read_buf_cap: usize,
403 }
404 
405 impl PollDataChannel {
406     /// Constructs a new `PollDataChannel`.
407     ///
408     /// # Examples
409     ///
410     /// ```
411     /// use webrtc_data::data_channel::{DataChannel, PollDataChannel, Config};
412     /// use sctp::stream::Stream;
413     /// use std::sync::Arc;
414     ///
415     /// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default()));
416     /// let poll_dc = PollDataChannel::new(dc);
417     /// ```
new(data_channel: Arc<DataChannel>) -> Self418     pub fn new(data_channel: Arc<DataChannel>) -> Self {
419         Self {
420             data_channel,
421             read_fut: ReadFut::Idle,
422             write_fut: None,
423             shutdown_fut: None,
424             read_buf_cap: DEFAULT_READ_BUF_SIZE,
425         }
426     }
427 
428     /// Get back the inner data_channel.
into_inner(self) -> Arc<DataChannel>429     pub fn into_inner(self) -> Arc<DataChannel> {
430         self.data_channel
431     }
432 
433     /// Obtain a clone of the inner data_channel.
clone_inner(&self) -> Arc<DataChannel>434     pub fn clone_inner(&self) -> Arc<DataChannel> {
435         self.data_channel.clone()
436     }
437 
438     /// MessagesSent returns the number of messages sent
messages_sent(&self) -> usize439     pub fn messages_sent(&self) -> usize {
440         self.data_channel.messages_sent()
441     }
442 
443     /// MessagesReceived returns the number of messages received
messages_received(&self) -> usize444     pub fn messages_received(&self) -> usize {
445         self.data_channel.messages_received()
446     }
447 
448     /// BytesSent returns the number of bytes sent
bytes_sent(&self) -> usize449     pub fn bytes_sent(&self) -> usize {
450         self.data_channel.bytes_sent()
451     }
452 
453     /// BytesReceived returns the number of bytes received
bytes_received(&self) -> usize454     pub fn bytes_received(&self) -> usize {
455         self.data_channel.bytes_received()
456     }
457 
458     /// StreamIdentifier returns the Stream identifier associated to the stream.
stream_identifier(&self) -> u16459     pub fn stream_identifier(&self) -> u16 {
460         self.data_channel.stream_identifier()
461     }
462 
463     /// BufferedAmount returns the number of bytes of data currently queued to be
464     /// sent over this stream.
buffered_amount(&self) -> usize465     pub fn buffered_amount(&self) -> usize {
466         self.data_channel.buffered_amount()
467     }
468 
469     /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
470     /// data that is considered "low." Defaults to 0.
buffered_amount_low_threshold(&self) -> usize471     pub fn buffered_amount_low_threshold(&self) -> usize {
472         self.data_channel.buffered_amount_low_threshold()
473     }
474 
475     /// Set the capacity of the temporary read buffer (default: 8192).
set_read_buf_capacity(&mut self, capacity: usize)476     pub fn set_read_buf_capacity(&mut self, capacity: usize) {
477         self.read_buf_cap = capacity
478     }
479 }
480 
481 impl AsyncRead for PollDataChannel {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>482     fn poll_read(
483         mut self: Pin<&mut Self>,
484         cx: &mut Context<'_>,
485         buf: &mut ReadBuf<'_>,
486     ) -> Poll<io::Result<()>> {
487         if buf.remaining() == 0 {
488             return Poll::Ready(Ok(()));
489         }
490 
491         let fut = match self.read_fut {
492             ReadFut::Idle => {
493                 // read into a temporary buffer because `buf` has an unonymous lifetime, which can
494                 // be shorter than the lifetime of `read_fut`.
495                 let data_channel = self.data_channel.clone();
496                 let mut temp_buf = vec![0; self.read_buf_cap];
497                 self.read_fut = ReadFut::Reading(Box::pin(async move {
498                     data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
499                         temp_buf.truncate(n);
500                         temp_buf
501                     })
502                 }));
503                 self.read_fut.get_reading_mut()
504             }
505             ReadFut::Reading(ref mut fut) => fut,
506             ReadFut::RemainingData(ref mut data) => {
507                 let remaining = buf.remaining();
508                 let len = std::cmp::min(data.len(), remaining);
509                 buf.put_slice(&data[..len]);
510                 if data.len() > remaining {
511                     // ReadFut remains to be RemainingData
512                     data.drain(..len);
513                 } else {
514                     self.read_fut = ReadFut::Idle;
515                 }
516                 return Poll::Ready(Ok(()));
517             }
518         };
519 
520         loop {
521             match fut.as_mut().poll(cx) {
522                 Poll::Pending => return Poll::Pending,
523                 // retry immediately upon empty data or incomplete chunks
524                 // since there's no way to setup a waker.
525                 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
526                 // EOF has been reached => don't touch buf and just return Ok
527                 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
528                     self.read_fut = ReadFut::Idle;
529                     return Poll::Ready(Ok(()));
530                 }
531                 Poll::Ready(Err(e)) => {
532                     self.read_fut = ReadFut::Idle;
533                     return Poll::Ready(Err(e.into()));
534                 }
535                 Poll::Ready(Ok(mut temp_buf)) => {
536                     let remaining = buf.remaining();
537                     let len = std::cmp::min(temp_buf.len(), remaining);
538                     buf.put_slice(&temp_buf[..len]);
539                     if temp_buf.len() > remaining {
540                         temp_buf.drain(..len);
541                         self.read_fut = ReadFut::RemainingData(temp_buf);
542                     } else {
543                         self.read_fut = ReadFut::Idle;
544                     }
545                     return Poll::Ready(Ok(()));
546                 }
547             }
548         }
549     }
550 }
551 
552 impl AsyncWrite for PollDataChannel {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>553     fn poll_write(
554         mut self: Pin<&mut Self>,
555         cx: &mut Context<'_>,
556         buf: &[u8],
557     ) -> Poll<io::Result<usize>> {
558         if buf.is_empty() {
559             return Poll::Ready(Ok(0));
560         }
561 
562         if let Some(fut) = self.write_fut.as_mut() {
563             match fut.as_mut().poll(cx) {
564                 Poll::Pending => Poll::Pending,
565                 Poll::Ready(Err(e)) => {
566                     let data_channel = self.data_channel.clone();
567                     let bytes = Bytes::copy_from_slice(buf);
568                     self.write_fut =
569                         Some(Box::pin(async move { data_channel.write(&bytes).await }));
570                     Poll::Ready(Err(e.into()))
571                 }
572                 // Given the data is buffered, it's okay to ignore the number of written bytes.
573                 //
574                 // TODO: In the long term, `data_channel.write` should be made sync. Then we could
575                 // remove the whole `if` condition and just call `data_channel.write`.
576                 Poll::Ready(Ok(_)) => {
577                     let data_channel = self.data_channel.clone();
578                     let bytes = Bytes::copy_from_slice(buf);
579                     self.write_fut =
580                         Some(Box::pin(async move { data_channel.write(&bytes).await }));
581                     Poll::Ready(Ok(buf.len()))
582                 }
583             }
584         } else {
585             let data_channel = self.data_channel.clone();
586             let bytes = Bytes::copy_from_slice(buf);
587             let fut = self
588                 .write_fut
589                 .insert(Box::pin(async move { data_channel.write(&bytes).await }));
590 
591             match fut.as_mut().poll(cx) {
592                 // If it's the first time we're polling the future, `Poll::Pending` can't be
593                 // returned because that would mean the `PollDataChannel` is not ready for writing.
594                 // And this is not true since we've just created a future, which is going to write
595                 // the buf to the underlying stream.
596                 //
597                 // It's okay to return `Poll::Ready` if the data is buffered (this is what the
598                 // buffered writer and `File` do).
599                 Poll::Pending => Poll::Ready(Ok(buf.len())),
600                 Poll::Ready(Err(e)) => {
601                     self.write_fut = None;
602                     Poll::Ready(Err(e.into()))
603                 }
604                 Poll::Ready(Ok(n)) => {
605                     self.write_fut = None;
606                     Poll::Ready(Ok(n))
607                 }
608             }
609         }
610     }
611 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>612     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
613         match self.write_fut.as_mut() {
614             Some(fut) => match fut.as_mut().poll(cx) {
615                 Poll::Pending => Poll::Pending,
616                 Poll::Ready(Err(e)) => {
617                     self.write_fut = None;
618                     Poll::Ready(Err(e.into()))
619                 }
620                 Poll::Ready(Ok(_)) => {
621                     self.write_fut = None;
622                     Poll::Ready(Ok(()))
623                 }
624             },
625             None => Poll::Ready(Ok(())),
626         }
627     }
628 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>629     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
630         match self.as_mut().poll_flush(cx) {
631             Poll::Pending => return Poll::Pending,
632             Poll::Ready(_) => {}
633         }
634 
635         let fut = match self.shutdown_fut.as_mut() {
636             Some(fut) => fut,
637             None => {
638                 let data_channel = self.data_channel.clone();
639                 self.shutdown_fut.get_or_insert(Box::pin(async move {
640                     data_channel
641                         .stream
642                         .shutdown(Shutdown::Write)
643                         .await
644                         .map_err(Error::Sctp)
645                 }))
646             }
647         };
648 
649         match fut.as_mut().poll(cx) {
650             Poll::Pending => Poll::Pending,
651             Poll::Ready(Err(e)) => {
652                 self.shutdown_fut = None;
653                 Poll::Ready(Err(e.into()))
654             }
655             Poll::Ready(Ok(_)) => {
656                 self.shutdown_fut = None;
657                 Poll::Ready(Ok(()))
658             }
659         }
660     }
661 }
662 
663 impl Clone for PollDataChannel {
clone(&self) -> PollDataChannel664     fn clone(&self) -> PollDataChannel {
665         PollDataChannel::new(self.clone_inner())
666     }
667 }
668 
669 impl fmt::Debug for PollDataChannel {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result670     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671         f.debug_struct("PollDataChannel")
672             .field("data_channel", &self.data_channel)
673             .field("read_buf_cap", &self.read_buf_cap)
674             .finish()
675     }
676 }
677 
678 impl AsRef<DataChannel> for PollDataChannel {
as_ref(&self) -> &DataChannel679     fn as_ref(&self) -> &DataChannel {
680         &self.data_channel
681     }
682 }
683