xref: /tonic/tests/compression/src/util.rs (revision ff71e893)
1 use super::*;
2 use bytes::Bytes;
3 use http_body::Body;
4 use pin_project::pin_project;
5 use std::{
6     pin::Pin,
7     sync::{
8         atomic::{AtomicUsize, Ordering::SeqCst},
9         Arc,
10     },
11     task::{ready, Context, Poll},
12 };
13 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14 use tonic::transport::{server::Connected, Channel};
15 use tower_http::map_request_body::MapRequestBodyLayer;
16 
17 /// A body that tracks how many bytes passes through it
18 #[pin_project]
19 pub struct CountBytesBody<B> {
20     #[pin]
21     pub inner: B,
22     pub counter: Arc<AtomicUsize>,
23 }
24 
25 impl<B> Body for CountBytesBody<B>
26 where
27     B: Body<Data = Bytes>,
28 {
29     type Data = B::Data;
30     type Error = B::Error;
31 
32     fn poll_data(
33         self: Pin<&mut Self>,
34         cx: &mut Context<'_>,
35     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
36         let this = self.project();
37         let counter: Arc<AtomicUsize> = this.counter.clone();
38         match ready!(this.inner.poll_data(cx)) {
39             Some(Ok(chunk)) => {
40                 println!("response body chunk size = {}", chunk.len());
41                 counter.fetch_add(chunk.len(), SeqCst);
42                 Poll::Ready(Some(Ok(chunk)))
43             }
44             x => Poll::Ready(x),
45         }
46     }
47 
48     fn poll_trailers(
49         self: Pin<&mut Self>,
50         cx: &mut Context<'_>,
51     ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
52         self.project().inner.poll_trailers(cx)
53     }
54 
55     fn is_end_stream(&self) -> bool {
56         self.inner.is_end_stream()
57     }
58 
59     fn size_hint(&self) -> http_body::SizeHint {
60         self.inner.size_hint()
61     }
62 }
63 
64 #[allow(dead_code)]
65 pub fn measure_request_body_size_layer(
66     bytes_sent_counter: Arc<AtomicUsize>,
67 ) -> MapRequestBodyLayer<impl Fn(hyper::Body) -> hyper::Body + Clone> {
68     MapRequestBodyLayer::new(move |mut body: hyper::Body| {
69         let (mut tx, new_body) = hyper::Body::channel();
70 
71         let bytes_sent_counter = bytes_sent_counter.clone();
72         tokio::spawn(async move {
73             while let Some(chunk) = body.data().await {
74                 let chunk = chunk.unwrap();
75                 println!("request body chunk size = {}", chunk.len());
76                 bytes_sent_counter.fetch_add(chunk.len(), SeqCst);
77                 tx.send_data(chunk).await.unwrap();
78             }
79 
80             if let Some(trailers) = body.trailers().await.unwrap() {
81                 tx.send_trailers(trailers).await.unwrap();
82             }
83         });
84 
85         new_body
86     })
87 }
88 
89 #[allow(dead_code)]
90 pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel {
91     let mut client = Some(client);
92 
93     Endpoint::try_from("http://[::]:50051")
94         .unwrap()
95         .connect_with_connector(service_fn(move |_: Uri| {
96             let client = client.take().unwrap();
97             async move { Ok::<_, std::io::Error>(client) }
98         }))
99         .await
100         .unwrap()
101 }
102