1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4
5 util::parametrized_tests! {
6 client_enabled_server_enabled,
7 zstd: CompressionEncoding::Zstd,
8 gzip: CompressionEncoding::Gzip,
9 deflate: CompressionEncoding::Deflate,
10 }
11
12 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15
16 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
17
18 let request_bytes_counter = Arc::new(AtomicUsize::new(0));
19
20 #[derive(Clone)]
21 pub struct AssertRightEncoding {
22 encoding: CompressionEncoding,
23 }
24
25 #[allow(dead_code)]
26 impl AssertRightEncoding {
27 pub fn new(encoding: CompressionEncoding) -> Self {
28 Self { encoding }
29 }
30
31 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
32 let expected = match self.encoding {
33 CompressionEncoding::Gzip => "gzip",
34 CompressionEncoding::Zstd => "zstd",
35 CompressionEncoding::Deflate => "deflate",
36 _ => panic!("unexpected encoding {:?}", self.encoding),
37 };
38 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
39
40 req
41 }
42 }
43
44 tokio::spawn({
45 let request_bytes_counter = request_bytes_counter.clone();
46 async move {
47 Server::builder()
48 .layer(
49 ServiceBuilder::new()
50 .map_request(move |req| {
51 AssertRightEncoding::new(encoding).clone().call(req)
52 })
53 .layer(measure_request_body_size_layer(
54 request_bytes_counter.clone(),
55 ))
56 .into_inner(),
57 )
58 .add_service(svc)
59 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
60 .await
61 .unwrap();
62 }
63 });
64
65 let mut client =
66 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
67
68 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
69 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
70 let req = Request::new(Box::pin(stream));
71
72 client.compress_input_client_stream(req).await.unwrap();
73
74 let bytes_sent = request_bytes_counter.load(SeqCst);
75 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
76 }
77
78 util::parametrized_tests! {
79 client_disabled_server_enabled,
80 zstd: CompressionEncoding::Zstd,
81 gzip: CompressionEncoding::Gzip,
82 deflate: CompressionEncoding::Deflate,
83 }
84
85 #[allow(dead_code)]
client_disabled_server_enabled(encoding: CompressionEncoding)86 async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
87 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
88
89 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
90
91 let request_bytes_counter = Arc::new(AtomicUsize::new(0));
92
93 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
94 assert!(req.headers().get("grpc-encoding").is_none());
95 req
96 }
97
98 tokio::spawn({
99 let request_bytes_counter = request_bytes_counter.clone();
100 async move {
101 Server::builder()
102 .layer(
103 ServiceBuilder::new()
104 .map_request(assert_right_encoding)
105 .layer(measure_request_body_size_layer(
106 request_bytes_counter.clone(),
107 ))
108 .into_inner(),
109 )
110 .add_service(svc)
111 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
112 .await
113 .unwrap();
114 }
115 });
116
117 let mut client = test_client::TestClient::new(mock_io_channel(client).await);
118
119 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
120 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
121 let req = Request::new(Box::pin(stream));
122
123 client.compress_input_client_stream(req).await.unwrap();
124
125 let bytes_sent = request_bytes_counter.load(SeqCst);
126 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
127 }
128
129 util::parametrized_tests! {
130 client_enabled_server_disabled,
131 zstd: CompressionEncoding::Zstd,
132 gzip: CompressionEncoding::Gzip,
133 deflate: CompressionEncoding::Deflate,
134 }
135
136 #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)137 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
138 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
139
140 let svc = test_server::TestServer::new(Svc::default());
141
142 tokio::spawn(async move {
143 Server::builder()
144 .add_service(svc)
145 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
146 .await
147 .unwrap();
148 });
149
150 let mut client =
151 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
152
153 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
154 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
155 let req = Request::new(Box::pin(stream));
156
157 let status = client.compress_input_client_stream(req).await.unwrap_err();
158
159 assert_eq!(status.code(), tonic::Code::Unimplemented);
160 let expected = match encoding {
161 CompressionEncoding::Gzip => "gzip",
162 CompressionEncoding::Zstd => "zstd",
163 CompressionEncoding::Deflate => "deflate",
164 _ => panic!("unexpected encoding {:?}", encoding),
165 };
166 assert_eq!(
167 status.message(),
168 format!(
169 "Content is compressed with `{}` which isn't supported",
170 expected
171 )
172 );
173 }
174
175 util::parametrized_tests! {
176 compressing_response_from_client_stream,
177 zstd: CompressionEncoding::Zstd,
178 gzip: CompressionEncoding::Gzip,
179 deflate: CompressionEncoding::Deflate,
180 }
181
182 #[allow(dead_code)]
compressing_response_from_client_stream(encoding: CompressionEncoding)183 async fn compressing_response_from_client_stream(encoding: CompressionEncoding) {
184 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
185
186 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
187
188 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
189
190 tokio::spawn({
191 let response_bytes_counter = response_bytes_counter.clone();
192 async move {
193 Server::builder()
194 .layer(
195 ServiceBuilder::new()
196 .layer(MapResponseBodyLayer::new(move |body| {
197 util::CountBytesBody {
198 inner: body,
199 counter: response_bytes_counter.clone(),
200 }
201 }))
202 .into_inner(),
203 )
204 .add_service(svc)
205 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
206 .await
207 .unwrap();
208 }
209 });
210
211 let mut client =
212 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
213
214 let req = Request::new(Box::pin(tokio_stream::empty()));
215
216 let res = client.compress_output_client_stream(req).await.unwrap();
217 let expected = match encoding {
218 CompressionEncoding::Gzip => "gzip",
219 CompressionEncoding::Zstd => "zstd",
220 CompressionEncoding::Deflate => "deflate",
221 _ => panic!("unexpected encoding {:?}", encoding),
222 };
223 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
224 let bytes_sent = response_bytes_counter.load(SeqCst);
225 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
226 }
227