1 use super::*;
2 use tonic::codec::CompressionEncoding;
3
4 util::parametrized_tests! {
5 client_enabled_server_enabled,
6 zstd: CompressionEncoding::Zstd,
7 gzip: CompressionEncoding::Gzip,
8 deflate: CompressionEncoding::Deflate,
9 }
10
11 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)12 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
13 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
14
15 #[derive(Clone, Copy)]
16 struct AssertCorrectAcceptEncoding<S> {
17 service: S,
18 encoding: CompressionEncoding,
19 }
20
21 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
22 where
23 S: Service<http::Request<B>>,
24 {
25 type Response = S::Response;
26 type Error = S::Error;
27 type Future = S::Future;
28
29 fn poll_ready(
30 &mut self,
31 cx: &mut std::task::Context<'_>,
32 ) -> std::task::Poll<Result<(), Self::Error>> {
33 self.service.poll_ready(cx)
34 }
35
36 fn call(&mut self, req: http::Request<B>) -> Self::Future {
37 let expected = match self.encoding {
38 CompressionEncoding::Gzip => "gzip",
39 CompressionEncoding::Zstd => "zstd",
40 CompressionEncoding::Deflate => "deflate",
41 _ => panic!("unexpected encoding {:?}", self.encoding),
42 };
43 assert_eq!(
44 req.headers()
45 .get("grpc-accept-encoding")
46 .unwrap()
47 .to_str()
48 .unwrap(),
49 format!("{},identity", expected)
50 );
51 self.service.call(req)
52 }
53 }
54
55 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
56
57 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
58
59 tokio::spawn({
60 let response_bytes_counter = response_bytes_counter.clone();
61 async move {
62 Server::builder()
63 .layer(
64 ServiceBuilder::new()
65 .layer(layer_fn(|service| AssertCorrectAcceptEncoding {
66 service,
67 encoding,
68 }))
69 .layer(MapResponseBodyLayer::new(move |body| {
70 util::CountBytesBody {
71 inner: body,
72 counter: response_bytes_counter.clone(),
73 }
74 }))
75 .into_inner(),
76 )
77 .add_service(svc)
78 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
79 .await
80 .unwrap();
81 }
82 });
83
84 let mut client =
85 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
86
87 let expected = match encoding {
88 CompressionEncoding::Gzip => "gzip",
89 CompressionEncoding::Zstd => "zstd",
90 CompressionEncoding::Deflate => "deflate",
91 _ => panic!("unexpected encoding {:?}", encoding),
92 };
93
94 for _ in 0..3 {
95 let res = client.compress_output_unary(()).await.unwrap();
96 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
97 let bytes_sent = response_bytes_counter.load(SeqCst);
98 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
99 }
100 }
101
102 util::parametrized_tests! {
103 client_enabled_server_disabled,
104 zstd: CompressionEncoding::Zstd,
105 gzip: CompressionEncoding::Gzip,
106 deflate: CompressionEncoding::Deflate,
107 }
108
109 #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)110 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
111 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
112
113 let svc = test_server::TestServer::new(Svc::default());
114
115 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
116
117 tokio::spawn({
118 let response_bytes_counter = response_bytes_counter.clone();
119 async move {
120 Server::builder()
121 // no compression enable on the server so responses should not be compressed
122 .layer(
123 ServiceBuilder::new()
124 .layer(MapResponseBodyLayer::new(move |body| {
125 util::CountBytesBody {
126 inner: body,
127 counter: response_bytes_counter.clone(),
128 }
129 }))
130 .into_inner(),
131 )
132 .add_service(svc)
133 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
134 .await
135 .unwrap();
136 }
137 });
138
139 let mut client =
140 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
141
142 let res = client.compress_output_unary(()).await.unwrap();
143
144 assert!(res.metadata().get("grpc-encoding").is_none());
145
146 let bytes_sent = response_bytes_counter.load(SeqCst);
147 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
148 }
149
150 #[tokio::test(flavor = "multi_thread")]
client_enabled_server_disabled_multi_encoding()151 async fn client_enabled_server_disabled_multi_encoding() {
152 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
153
154 let svc = test_server::TestServer::new(Svc::default());
155
156 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
157
158 tokio::spawn({
159 let response_bytes_counter = response_bytes_counter.clone();
160 async move {
161 Server::builder()
162 // no compression enable on the server so responses should not be compressed
163 .layer(
164 ServiceBuilder::new()
165 .layer(MapResponseBodyLayer::new(move |body| {
166 util::CountBytesBody {
167 inner: body,
168 counter: response_bytes_counter.clone(),
169 }
170 }))
171 .into_inner(),
172 )
173 .add_service(svc)
174 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
175 .await
176 .unwrap();
177 }
178 });
179
180 let mut client = test_client::TestClient::new(mock_io_channel(client).await)
181 .accept_compressed(CompressionEncoding::Gzip)
182 .accept_compressed(CompressionEncoding::Zstd)
183 .accept_compressed(CompressionEncoding::Deflate);
184
185 let res = client.compress_output_unary(()).await.unwrap();
186
187 assert!(res.metadata().get("grpc-encoding").is_none());
188
189 let bytes_sent = response_bytes_counter.load(SeqCst);
190 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
191 }
192
193 util::parametrized_tests! {
194 client_disabled,
195 zstd: CompressionEncoding::Zstd,
196 gzip: CompressionEncoding::Gzip,
197 deflate: CompressionEncoding::Deflate,
198 }
199
200 #[allow(dead_code)]
client_disabled(encoding: CompressionEncoding)201 async fn client_disabled(encoding: CompressionEncoding) {
202 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
203
204 #[derive(Clone, Copy)]
205 struct AssertCorrectAcceptEncoding<S>(S);
206
207 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
208 where
209 S: Service<http::Request<B>>,
210 {
211 type Response = S::Response;
212 type Error = S::Error;
213 type Future = S::Future;
214
215 fn poll_ready(
216 &mut self,
217 cx: &mut std::task::Context<'_>,
218 ) -> std::task::Poll<Result<(), Self::Error>> {
219 self.0.poll_ready(cx)
220 }
221
222 fn call(&mut self, req: http::Request<B>) -> Self::Future {
223 assert!(req.headers().get("grpc-accept-encoding").is_none());
224 self.0.call(req)
225 }
226 }
227
228 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
229
230 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
231
232 tokio::spawn({
233 let response_bytes_counter = response_bytes_counter.clone();
234 async move {
235 Server::builder()
236 .layer(
237 ServiceBuilder::new()
238 .layer(layer_fn(AssertCorrectAcceptEncoding))
239 .layer(MapResponseBodyLayer::new(move |body| {
240 util::CountBytesBody {
241 inner: body,
242 counter: response_bytes_counter.clone(),
243 }
244 }))
245 .into_inner(),
246 )
247 .add_service(svc)
248 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
249 .await
250 .unwrap();
251 }
252 });
253
254 let mut client = test_client::TestClient::new(mock_io_channel(client).await);
255
256 let res = client.compress_output_unary(()).await.unwrap();
257
258 assert!(res.metadata().get("grpc-encoding").is_none());
259
260 let bytes_sent = response_bytes_counter.load(SeqCst);
261 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
262 }
263
264 util::parametrized_tests! {
265 server_replying_with_unsupported_encoding,
266 zstd: CompressionEncoding::Zstd,
267 gzip: CompressionEncoding::Gzip,
268 deflate: CompressionEncoding::Deflate,
269 }
270
271 #[allow(dead_code)]
server_replying_with_unsupported_encoding(encoding: CompressionEncoding)272 async fn server_replying_with_unsupported_encoding(encoding: CompressionEncoding) {
273 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
274
275 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
276
277 fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> {
278 response
279 .headers_mut()
280 .insert("grpc-encoding", "br".parse().unwrap());
281 response
282 }
283
284 tokio::spawn(async move {
285 Server::builder()
286 .layer(
287 ServiceBuilder::new()
288 .map_response(add_weird_content_encoding)
289 .into_inner(),
290 )
291 .add_service(svc)
292 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
293 .await
294 .unwrap();
295 });
296
297 let mut client =
298 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
299 let status: Status = client.compress_output_unary(()).await.unwrap_err();
300
301 assert_eq!(status.code(), tonic::Code::Unimplemented);
302 assert_eq!(
303 status.message(),
304 "Content is compressed with `br` which isn't supported"
305 );
306 }
307
308 util::parametrized_tests! {
309 disabling_compression_on_single_response,
310 zstd: CompressionEncoding::Zstd,
311 gzip: CompressionEncoding::Gzip,
312 deflate: CompressionEncoding::Deflate,
313 }
314
315 #[allow(dead_code)]
disabling_compression_on_single_response(encoding: CompressionEncoding)316 async fn disabling_compression_on_single_response(encoding: CompressionEncoding) {
317 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
318
319 let svc = test_server::TestServer::new(Svc {
320 disable_compressing_on_response: true,
321 })
322 .send_compressed(encoding);
323
324 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
325
326 tokio::spawn({
327 let response_bytes_counter = response_bytes_counter.clone();
328 async move {
329 Server::builder()
330 .layer(
331 ServiceBuilder::new()
332 .layer(MapResponseBodyLayer::new(move |body| {
333 util::CountBytesBody {
334 inner: body,
335 counter: response_bytes_counter.clone(),
336 }
337 }))
338 .into_inner(),
339 )
340 .add_service(svc)
341 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
342 .await
343 .unwrap();
344 }
345 });
346
347 let mut client =
348 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
349
350 let res = client.compress_output_unary(()).await.unwrap();
351
352 let expected = match encoding {
353 CompressionEncoding::Gzip => "gzip",
354 CompressionEncoding::Zstd => "zstd",
355 CompressionEncoding::Deflate => "deflate",
356 _ => panic!("unexpected encoding {:?}", encoding),
357 };
358 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
359
360 let bytes_sent = response_bytes_counter.load(SeqCst);
361 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
362 }
363
364 util::parametrized_tests! {
365 disabling_compression_on_response_but_keeping_compression_on_stream,
366 zstd: CompressionEncoding::Zstd,
367 gzip: CompressionEncoding::Gzip,
368 deflate: CompressionEncoding::Deflate,
369 }
370
371 #[allow(dead_code)]
disabling_compression_on_response_but_keeping_compression_on_stream( encoding: CompressionEncoding, )372 async fn disabling_compression_on_response_but_keeping_compression_on_stream(
373 encoding: CompressionEncoding,
374 ) {
375 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
376
377 let svc = test_server::TestServer::new(Svc {
378 disable_compressing_on_response: true,
379 })
380 .send_compressed(encoding);
381
382 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
383
384 tokio::spawn({
385 let response_bytes_counter = response_bytes_counter.clone();
386 async move {
387 Server::builder()
388 .layer(
389 ServiceBuilder::new()
390 .layer(MapResponseBodyLayer::new(move |body| {
391 util::CountBytesBody {
392 inner: body,
393 counter: response_bytes_counter.clone(),
394 }
395 }))
396 .into_inner(),
397 )
398 .add_service(svc)
399 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
400 .await
401 .unwrap();
402 }
403 });
404
405 let mut client =
406 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
407
408 let res = client.compress_output_server_stream(()).await.unwrap();
409
410 let expected = match encoding {
411 CompressionEncoding::Gzip => "gzip",
412 CompressionEncoding::Zstd => "zstd",
413 CompressionEncoding::Deflate => "deflate",
414 _ => panic!("unexpected encoding {:?}", encoding),
415 };
416 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
417
418 let mut stream: Streaming<SomeData> = res.into_inner();
419
420 stream
421 .next()
422 .await
423 .expect("stream empty")
424 .expect("item was error");
425 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
426
427 stream
428 .next()
429 .await
430 .expect("stream empty")
431 .expect("item was error");
432 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
433 }
434
435 util::parametrized_tests! {
436 disabling_compression_on_response_from_client_stream,
437 zstd: CompressionEncoding::Zstd,
438 gzip: CompressionEncoding::Gzip,
439 deflate: CompressionEncoding::Deflate,
440 }
441
442 #[allow(dead_code)]
disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding)443 async fn disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding) {
444 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
445
446 let svc = test_server::TestServer::new(Svc {
447 disable_compressing_on_response: true,
448 })
449 .send_compressed(encoding);
450
451 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
452
453 tokio::spawn({
454 let response_bytes_counter = response_bytes_counter.clone();
455 async move {
456 Server::builder()
457 .layer(
458 ServiceBuilder::new()
459 .layer(MapResponseBodyLayer::new(move |body| {
460 util::CountBytesBody {
461 inner: body,
462 counter: response_bytes_counter.clone(),
463 }
464 }))
465 .into_inner(),
466 )
467 .add_service(svc)
468 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
469 .await
470 .unwrap();
471 }
472 });
473
474 let mut client =
475 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
476
477 let req = Request::new(Box::pin(tokio_stream::empty()));
478
479 let res = client.compress_output_client_stream(req).await.unwrap();
480
481 let expected = match encoding {
482 CompressionEncoding::Gzip => "gzip",
483 CompressionEncoding::Zstd => "zstd",
484 CompressionEncoding::Deflate => "deflate",
485 _ => panic!("unexpected encoding {:?}", encoding),
486 };
487 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
488 let bytes_sent = response_bytes_counter.load(SeqCst);
489 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
490 }
491