1 use super::*; 2 use bytes::Bytes; 3 use http_body::Body; 4 use pin_project::pin_project; 5 use std::{ 6 pin::Pin, 7 sync::{ 8 atomic::{AtomicUsize, Ordering::SeqCst}, 9 Arc, 10 }, 11 task::{ready, Context, Poll}, 12 }; 13 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 14 use tonic::transport::{server::Connected, Channel}; 15 use tower_http::map_request_body::MapRequestBodyLayer; 16 17 /// A body that tracks how many bytes passes through it 18 #[pin_project] 19 pub struct CountBytesBody<B> { 20 #[pin] 21 pub inner: B, 22 pub counter: Arc<AtomicUsize>, 23 } 24 25 impl<B> Body for CountBytesBody<B> 26 where 27 B: Body<Data = Bytes>, 28 { 29 type Data = B::Data; 30 type Error = B::Error; 31 32 fn poll_data( 33 self: Pin<&mut Self>, 34 cx: &mut Context<'_>, 35 ) -> Poll<Option<Result<Self::Data, Self::Error>>> { 36 let this = self.project(); 37 let counter: Arc<AtomicUsize> = this.counter.clone(); 38 match ready!(this.inner.poll_data(cx)) { 39 Some(Ok(chunk)) => { 40 println!("response body chunk size = {}", chunk.len()); 41 counter.fetch_add(chunk.len(), SeqCst); 42 Poll::Ready(Some(Ok(chunk))) 43 } 44 x => Poll::Ready(x), 45 } 46 } 47 48 fn poll_trailers( 49 self: Pin<&mut Self>, 50 cx: &mut Context<'_>, 51 ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> { 52 self.project().inner.poll_trailers(cx) 53 } 54 55 fn is_end_stream(&self) -> bool { 56 self.inner.is_end_stream() 57 } 58 59 fn size_hint(&self) -> http_body::SizeHint { 60 self.inner.size_hint() 61 } 62 } 63 64 #[allow(dead_code)] 65 pub fn measure_request_body_size_layer( 66 bytes_sent_counter: Arc<AtomicUsize>, 67 ) -> MapRequestBodyLayer<impl Fn(hyper::Body) -> hyper::Body + Clone> { 68 MapRequestBodyLayer::new(move |mut body: hyper::Body| { 69 let (mut tx, new_body) = hyper::Body::channel(); 70 71 let bytes_sent_counter = bytes_sent_counter.clone(); 72 tokio::spawn(async move { 73 while let Some(chunk) = body.data().await { 74 let chunk = chunk.unwrap(); 75 println!("request body chunk size = {}", chunk.len()); 76 bytes_sent_counter.fetch_add(chunk.len(), SeqCst); 77 tx.send_data(chunk).await.unwrap(); 78 } 79 80 if let Some(trailers) = body.trailers().await.unwrap() { 81 tx.send_trailers(trailers).await.unwrap(); 82 } 83 }); 84 85 new_body 86 }) 87 } 88 89 #[allow(dead_code)] 90 pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { 91 let mut client = Some(client); 92 93 Endpoint::try_from("http://[::]:50051") 94 .unwrap() 95 .connect_with_connector(service_fn(move |_: Uri| { 96 let client = client.take().unwrap(); 97 async move { Ok::<_, std::io::Error>(client) } 98 })) 99 .await 100 .unwrap() 101 } 102