1 use super::*;
2 use bytes::{Buf, Bytes};
3 use http_body::{Body as HttpBody, Frame};
4 use http_body_util::BodyExt as _;
5 use hyper_util::rt::TokioIo;
6 use pin_project::pin_project;
7 use std::{
8 pin::Pin,
9 sync::{
10 atomic::{AtomicUsize, Ordering::SeqCst},
11 Arc,
12 },
13 task::{ready, Context, Poll},
14 };
15 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16 use tonic::body::Body;
17 use tonic::codec::CompressionEncoding;
18 use tonic::transport::{server::Connected, Channel};
19 use tower_http::map_request_body::MapRequestBodyLayer;
20
/// Generates one multi-threaded `#[tokio::test]` per named input.
///
/// `parametrized_tests!(run, a: expr_a, b: expr_b)` expands to async test
/// functions named `run_a` and `run_b`. Each one evaluates its input
/// expression and awaits `run(input)`. So `run` must be callable with a single
/// argument and return a future (typically an `async fn`). A trailing comma
/// after the last case is accepted.
macro_rules! parametrized_tests {
    ($fn_name:ident, $($test_name:ident: $input:expr),+ $(,)?) => {
        // `paste!` builds each test's identifier by concatenating
        // `$fn_name`, `_` and `$test_name`.
        paste::paste! {
            $(
                #[tokio::test(flavor = "multi_thread")]
                async fn [<$fn_name _ $test_name>]() {
                    let input = $input;
                    $fn_name(input).await;
                }
            )+
        }
    }
}

// Re-exported so other modules in the crate can refer to the macro by path.
pub(crate) use parametrized_tests;
36
37 /// A body that tracks how many bytes passes through it
38 #[pin_project]
39 pub struct CountBytesBody<B> {
40 #[pin]
41 pub inner: B,
42 pub counter: Arc<AtomicUsize>,
43 }
44
45 impl<B> HttpBody for CountBytesBody<B>
46 where
47 B: HttpBody<Data = Bytes>,
48 {
49 type Data = B::Data;
50 type Error = B::Error;
51
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>52 fn poll_frame(
53 self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
56 let this = self.project();
57 let counter: Arc<AtomicUsize> = this.counter.clone();
58 match ready!(this.inner.poll_frame(cx)) {
59 Some(Ok(chunk)) => {
60 println!("response body chunk size = {}", frame_data_length(&chunk));
61 counter.fetch_add(frame_data_length(&chunk), SeqCst);
62 Poll::Ready(Some(Ok(chunk)))
63 }
64 x => Poll::Ready(x),
65 }
66 }
67
is_end_stream(&self) -> bool68 fn is_end_stream(&self) -> bool {
69 self.inner.is_end_stream()
70 }
71
size_hint(&self) -> http_body::SizeHint72 fn size_hint(&self) -> http_body::SizeHint {
73 self.inner.size_hint()
74 }
75 }
76
frame_data_length(frame: &http_body::Frame<Bytes>) -> usize77 fn frame_data_length(frame: &http_body::Frame<Bytes>) -> usize {
78 if let Some(data) = frame.data_ref() {
79 data.len()
80 } else {
81 0
82 }
83 }
84
85 #[pin_project]
86 struct ChannelBody<T> {
87 #[pin]
88 rx: tokio::sync::mpsc::Receiver<Frame<T>>,
89 }
90
91 impl<T> ChannelBody<T> {
new() -> (tokio::sync::mpsc::Sender<Frame<T>>, Self)92 pub fn new() -> (tokio::sync::mpsc::Sender<Frame<T>>, Self) {
93 let (tx, rx) = tokio::sync::mpsc::channel(32);
94 (tx, Self { rx })
95 }
96 }
97
98 impl<T> HttpBody for ChannelBody<T>
99 where
100 T: Buf,
101 {
102 type Data = T;
103 type Error = tonic::Status;
104
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>105 fn poll_frame(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
109 let frame = ready!(self.project().rx.poll_recv(cx));
110 Poll::Ready(frame.map(Ok))
111 }
112 }
113
114 #[allow(dead_code)]
measure_request_body_size_layer( bytes_sent_counter: Arc<AtomicUsize>, ) -> MapRequestBodyLayer<impl Fn(Body) -> Body + Clone>115 pub fn measure_request_body_size_layer(
116 bytes_sent_counter: Arc<AtomicUsize>,
117 ) -> MapRequestBodyLayer<impl Fn(Body) -> Body + Clone> {
118 MapRequestBodyLayer::new(move |mut body: Body| {
119 let (tx, new_body) = ChannelBody::new();
120
121 let bytes_sent_counter = bytes_sent_counter.clone();
122 tokio::spawn(async move {
123 while let Some(chunk) = body.frame().await {
124 let chunk = chunk.unwrap();
125 println!("request body chunk size = {}", frame_data_length(&chunk));
126 bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst);
127 tx.send(chunk).await.unwrap();
128 }
129 });
130
131 Body::new(new_body)
132 })
133 }
134
135 #[allow(dead_code)]
mock_io_channel(client: tokio::io::DuplexStream) -> Channel136 pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel {
137 let mut client = Some(client);
138
139 Endpoint::try_from("http://[::]:50051")
140 .unwrap()
141 .connect_with_connector(service_fn(move |_: Uri| {
142 let client = TokioIo::new(client.take().unwrap());
143 async move { Ok::<_, std::io::Error>(client) }
144 }))
145 .await
146 .unwrap()
147 }
148
149 #[derive(Clone)]
150 pub struct AssertRightEncoding {
151 encoding: CompressionEncoding,
152 }
153
154 #[allow(dead_code)]
155 impl AssertRightEncoding {
new(encoding: CompressionEncoding) -> Self156 pub fn new(encoding: CompressionEncoding) -> Self {
157 Self { encoding }
158 }
159
call<B: HttpBody>(self, req: http::Request<B>) -> http::Request<B>160 pub fn call<B: HttpBody>(self, req: http::Request<B>) -> http::Request<B> {
161 let expected = match self.encoding {
162 CompressionEncoding::Gzip => "gzip",
163 CompressionEncoding::Zstd => "zstd",
164 CompressionEncoding::Deflate => "deflate",
165 _ => panic!("unexpected encoding {:?}", self.encoding),
166 };
167 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
168
169 req
170 }
171 }
172