xref: /tonic/tests/compression/src/util.rs (revision 79a06cc8)
1 use super::*;
2 use bytes::{Buf, Bytes};
3 use http_body::{Body as HttpBody, Frame};
4 use http_body_util::BodyExt as _;
5 use hyper_util::rt::TokioIo;
6 use pin_project::pin_project;
7 use std::{
8     pin::Pin,
9     sync::{
10         atomic::{AtomicUsize, Ordering::SeqCst},
11         Arc,
12     },
13     task::{ready, Context, Poll},
14 };
15 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16 use tonic::body::Body;
17 use tonic::codec::CompressionEncoding;
18 use tonic::transport::{server::Connected, Channel};
19 use tower_http::map_request_body::MapRequestBodyLayer;
20 
/// Generates one multi-threaded `#[tokio::test]` per `case: input` pair.
///
/// Each generated test is named `<func>_<case>` and simply awaits `func(input)`.
macro_rules! parametrized_tests {
    ($func:ident, $($case:ident: $arg:expr),+ $(,)?) => {
        paste::paste! {
            $(
                #[tokio::test(flavor = "multi_thread")]
                async fn [<$func _ $case>]() {
                    let input = $arg;
                    $func(input).await;
                }
            )+
        }
    };
}

pub(crate) use parametrized_tests;
36 
37 /// A body that tracks how many bytes passes through it
38 #[pin_project]
39 pub struct CountBytesBody<B> {
40     #[pin]
41     pub inner: B,
42     pub counter: Arc<AtomicUsize>,
43 }
44 
45 impl<B> HttpBody for CountBytesBody<B>
46 where
47     B: HttpBody<Data = Bytes>,
48 {
49     type Data = B::Data;
50     type Error = B::Error;
51 
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>52     fn poll_frame(
53         self: Pin<&mut Self>,
54         cx: &mut Context<'_>,
55     ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
56         let this = self.project();
57         let counter: Arc<AtomicUsize> = this.counter.clone();
58         match ready!(this.inner.poll_frame(cx)) {
59             Some(Ok(chunk)) => {
60                 println!("response body chunk size = {}", frame_data_length(&chunk));
61                 counter.fetch_add(frame_data_length(&chunk), SeqCst);
62                 Poll::Ready(Some(Ok(chunk)))
63             }
64             x => Poll::Ready(x),
65         }
66     }
67 
is_end_stream(&self) -> bool68     fn is_end_stream(&self) -> bool {
69         self.inner.is_end_stream()
70     }
71 
size_hint(&self) -> http_body::SizeHint72     fn size_hint(&self) -> http_body::SizeHint {
73         self.inner.size_hint()
74     }
75 }
76 
frame_data_length(frame: &http_body::Frame<Bytes>) -> usize77 fn frame_data_length(frame: &http_body::Frame<Bytes>) -> usize {
78     if let Some(data) = frame.data_ref() {
79         data.len()
80     } else {
81         0
82     }
83 }
84 
85 #[pin_project]
86 struct ChannelBody<T> {
87     #[pin]
88     rx: tokio::sync::mpsc::Receiver<Frame<T>>,
89 }
90 
91 impl<T> ChannelBody<T> {
new() -> (tokio::sync::mpsc::Sender<Frame<T>>, Self)92     pub fn new() -> (tokio::sync::mpsc::Sender<Frame<T>>, Self) {
93         let (tx, rx) = tokio::sync::mpsc::channel(32);
94         (tx, Self { rx })
95     }
96 }
97 
98 impl<T> HttpBody for ChannelBody<T>
99 where
100     T: Buf,
101 {
102     type Data = T;
103     type Error = tonic::Status;
104 
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>105     fn poll_frame(
106         self: Pin<&mut Self>,
107         cx: &mut Context<'_>,
108     ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
109         let frame = ready!(self.project().rx.poll_recv(cx));
110         Poll::Ready(frame.map(Ok))
111     }
112 }
113 
114 #[allow(dead_code)]
measure_request_body_size_layer( bytes_sent_counter: Arc<AtomicUsize>, ) -> MapRequestBodyLayer<impl Fn(Body) -> Body + Clone>115 pub fn measure_request_body_size_layer(
116     bytes_sent_counter: Arc<AtomicUsize>,
117 ) -> MapRequestBodyLayer<impl Fn(Body) -> Body + Clone> {
118     MapRequestBodyLayer::new(move |mut body: Body| {
119         let (tx, new_body) = ChannelBody::new();
120 
121         let bytes_sent_counter = bytes_sent_counter.clone();
122         tokio::spawn(async move {
123             while let Some(chunk) = body.frame().await {
124                 let chunk = chunk.unwrap();
125                 println!("request body chunk size = {}", frame_data_length(&chunk));
126                 bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst);
127                 tx.send(chunk).await.unwrap();
128             }
129         });
130 
131         Body::new(new_body)
132     })
133 }
134 
135 #[allow(dead_code)]
mock_io_channel(client: tokio::io::DuplexStream) -> Channel136 pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel {
137     let mut client = Some(client);
138 
139     Endpoint::try_from("http://[::]:50051")
140         .unwrap()
141         .connect_with_connector(service_fn(move |_: Uri| {
142             let client = TokioIo::new(client.take().unwrap());
143             async move { Ok::<_, std::io::Error>(client) }
144         }))
145         .await
146         .unwrap()
147 }
148 
149 #[derive(Clone)]
150 pub struct AssertRightEncoding {
151     encoding: CompressionEncoding,
152 }
153 
154 #[allow(dead_code)]
155 impl AssertRightEncoding {
new(encoding: CompressionEncoding) -> Self156     pub fn new(encoding: CompressionEncoding) -> Self {
157         Self { encoding }
158     }
159 
call<B: HttpBody>(self, req: http::Request<B>) -> http::Request<B>160     pub fn call<B: HttpBody>(self, req: http::Request<B>) -> http::Request<B> {
161         let expected = match self.encoding {
162             CompressionEncoding::Gzip => "gzip",
163             CompressionEncoding::Zstd => "zstd",
164             CompressionEncoding::Deflate => "deflate",
165             _ => panic!("unexpected encoding {:?}", self.encoding),
166         };
167         assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
168 
169         req
170     }
171 }
172