1 use super::*; 2 use bytes::{Buf, Bytes}; 3 use http_body::{Body as HttpBody, Frame}; 4 use http_body_util::BodyExt as _; 5 use hyper_util::rt::TokioIo; 6 use pin_project::pin_project; 7 use std::{ 8 pin::Pin, 9 sync::{ 10 atomic::{AtomicUsize, Ordering::SeqCst}, 11 Arc, 12 }, 13 task::{ready, Context, Poll}, 14 }; 15 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 16 use tonic::body::Body; 17 use tonic::codec::CompressionEncoding; 18 use tonic::transport::{server::Connected, Channel}; 19 use tower_http::map_request_body::MapRequestBodyLayer; 20 21 macro_rules! parametrized_tests { 22 ($fn_name:ident, $($test_name:ident: $input:expr),+ $(,)?) => { 23 paste::paste! { 24 $( 25 #[tokio::test(flavor = "multi_thread")] 26 async fn [<$fn_name _ $test_name>]() { 27 let input = $input; 28 $fn_name(input).await; 29 } 30 )+ 31 } 32 } 33 } 34 35 pub(crate) use parametrized_tests; 36 37 /// A body that tracks how many bytes passes through it 38 #[pin_project] 39 pub struct CountBytesBody<B> { 40 #[pin] 41 pub inner: B, 42 pub counter: Arc<AtomicUsize>, 43 } 44 45 impl<B> HttpBody for CountBytesBody<B> 46 where 47 B: HttpBody<Data = Bytes>, 48 { 49 type Data = B::Data; 50 type Error = B::Error; 51 52 fn poll_frame( 53 self: Pin<&mut Self>, 54 cx: &mut Context<'_>, 55 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { 56 let this = self.project(); 57 let counter: Arc<AtomicUsize> = this.counter.clone(); 58 match ready!(this.inner.poll_frame(cx)) { 59 Some(Ok(chunk)) => { 60 println!("response body chunk size = {}", frame_data_length(&chunk)); 61 counter.fetch_add(frame_data_length(&chunk), SeqCst); 62 Poll::Ready(Some(Ok(chunk))) 63 } 64 x => Poll::Ready(x), 65 } 66 } 67 68 fn is_end_stream(&self) -> bool { 69 self.inner.is_end_stream() 70 } 71 72 fn size_hint(&self) -> http_body::SizeHint { 73 self.inner.size_hint() 74 } 75 } 76 77 fn frame_data_length(frame: &http_body::Frame<Bytes>) -> usize { 78 if let Some(data) = frame.data_ref() { 79 data.len() 80 } else { 81 0 82 } 83 } 84 85 #[pin_project] 86 struct ChannelBody<T> { 87 #[pin] 88 rx: tokio::sync::mpsc::Receiver<Frame<T>>, 89 } 90 91 impl<T> ChannelBody<T> { 92 pub fn new() -> (tokio::sync::mpsc::Sender<Frame<T>>, Self) { 93 let (tx, rx) = tokio::sync::mpsc::channel(32); 94 (tx, Self { rx }) 95 } 96 } 97 98 impl<T> HttpBody for ChannelBody<T> 99 where 100 T: Buf, 101 { 102 type Data = T; 103 type Error = tonic::Status; 104 105 fn poll_frame( 106 self: Pin<&mut Self>, 107 cx: &mut Context<'_>, 108 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { 109 let frame = ready!(self.project().rx.poll_recv(cx)); 110 Poll::Ready(frame.map(Ok)) 111 } 112 } 113 114 #[allow(dead_code)] 115 pub fn measure_request_body_size_layer( 116 bytes_sent_counter: Arc<AtomicUsize>, 117 ) -> MapRequestBodyLayer<impl Fn(Body) -> Body + Clone> { 118 MapRequestBodyLayer::new(move |mut body: Body| { 119 let (tx, new_body) = ChannelBody::new(); 120 121 let bytes_sent_counter = bytes_sent_counter.clone(); 122 tokio::spawn(async move { 123 while let Some(chunk) = body.frame().await { 124 let chunk = chunk.unwrap(); 125 println!("request body chunk size = {}", frame_data_length(&chunk)); 126 bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst); 127 tx.send(chunk).await.unwrap(); 128 } 129 }); 130 131 Body::new(new_body) 132 }) 133 } 134 135 #[allow(dead_code)] 136 pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { 137 let mut client = Some(client); 138 139 Endpoint::try_from("http://[::]:50051") 140 .unwrap() 141 .connect_with_connector(service_fn(move |_: Uri| { 142 let client = TokioIo::new(client.take().unwrap()); 143 async move { Ok::<_, std::io::Error>(client) } 144 })) 145 .await 146 .unwrap() 147 } 148 149 #[derive(Clone)] 150 pub struct AssertRightEncoding { 151 encoding: CompressionEncoding, 152 } 153 154 #[allow(dead_code)] 155 impl AssertRightEncoding { 156 pub fn new(encoding: CompressionEncoding) -> Self { 157 Self { encoding } 158 } 159 160 pub fn call<B: HttpBody>(self, req: http::Request<B>) -> http::Request<B> { 161 let expected = match self.encoding { 162 CompressionEncoding::Gzip => "gzip", 163 CompressionEncoding::Zstd => "zstd", 164 CompressionEncoding::Deflate => "deflate", 165 _ => panic!("unexpected encoding {:?}", self.encoding), 166 }; 167 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); 168 169 req 170 } 171 } 172