1 use super::*;
2 use tonic::codec::CompressionEncoding;
3 
// Parametrizes `client_enabled_server_enabled` over the compression encodings
// listed below. NOTE(review): presumably expands into one test per
// `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    client_enabled_server_enabled,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
10 
11 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)12 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
13     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
14 
15     #[derive(Clone, Copy)]
16     struct AssertCorrectAcceptEncoding<S> {
17         service: S,
18         encoding: CompressionEncoding,
19     }
20 
21     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
22     where
23         S: Service<http::Request<B>>,
24     {
25         type Response = S::Response;
26         type Error = S::Error;
27         type Future = S::Future;
28 
29         fn poll_ready(
30             &mut self,
31             cx: &mut std::task::Context<'_>,
32         ) -> std::task::Poll<Result<(), Self::Error>> {
33             self.service.poll_ready(cx)
34         }
35 
36         fn call(&mut self, req: http::Request<B>) -> Self::Future {
37             let expected = match self.encoding {
38                 CompressionEncoding::Gzip => "gzip",
39                 CompressionEncoding::Zstd => "zstd",
40                 CompressionEncoding::Deflate => "deflate",
41                 _ => panic!("unexpected encoding {:?}", self.encoding),
42             };
43             assert_eq!(
44                 req.headers()
45                     .get("grpc-accept-encoding")
46                     .unwrap()
47                     .to_str()
48                     .unwrap(),
49                 format!("{},identity", expected)
50             );
51             self.service.call(req)
52         }
53     }
54 
55     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
56 
57     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
58 
59     tokio::spawn({
60         let response_bytes_counter = response_bytes_counter.clone();
61         async move {
62             Server::builder()
63                 .layer(
64                     ServiceBuilder::new()
65                         .layer(layer_fn(|service| AssertCorrectAcceptEncoding {
66                             service,
67                             encoding,
68                         }))
69                         .layer(MapResponseBodyLayer::new(move |body| {
70                             util::CountBytesBody {
71                                 inner: body,
72                                 counter: response_bytes_counter.clone(),
73                             }
74                         }))
75                         .into_inner(),
76                 )
77                 .add_service(svc)
78                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
79                 .await
80                 .unwrap();
81         }
82     });
83 
84     let mut client =
85         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
86 
87     let expected = match encoding {
88         CompressionEncoding::Gzip => "gzip",
89         CompressionEncoding::Zstd => "zstd",
90         CompressionEncoding::Deflate => "deflate",
91         _ => panic!("unexpected encoding {:?}", encoding),
92     };
93 
94     for _ in 0..3 {
95         let res = client.compress_output_unary(()).await.unwrap();
96         assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
97         let bytes_sent = response_bytes_counter.load(SeqCst);
98         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
99     }
100 }
101 
// Parametrizes `client_enabled_server_disabled` over the compression encodings
// listed below. NOTE(review): presumably expands into one test per
// `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    client_enabled_server_disabled,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
108 
109 #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)110 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
111     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
112 
113     let svc = test_server::TestServer::new(Svc::default());
114 
115     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
116 
117     tokio::spawn({
118         let response_bytes_counter = response_bytes_counter.clone();
119         async move {
120             Server::builder()
121                 // no compression enable on the server so responses should not be compressed
122                 .layer(
123                     ServiceBuilder::new()
124                         .layer(MapResponseBodyLayer::new(move |body| {
125                             util::CountBytesBody {
126                                 inner: body,
127                                 counter: response_bytes_counter.clone(),
128                             }
129                         }))
130                         .into_inner(),
131                 )
132                 .add_service(svc)
133                 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
134                 .await
135                 .unwrap();
136         }
137     });
138 
139     let mut client =
140         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
141 
142     let res = client.compress_output_unary(()).await.unwrap();
143 
144     assert!(res.metadata().get("grpc-encoding").is_none());
145 
146     let bytes_sent = response_bytes_counter.load(SeqCst);
147     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
148 }
149 
150 #[tokio::test(flavor = "multi_thread")]
client_enabled_server_disabled_multi_encoding()151 async fn client_enabled_server_disabled_multi_encoding() {
152     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
153 
154     let svc = test_server::TestServer::new(Svc::default());
155 
156     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
157 
158     tokio::spawn({
159         let response_bytes_counter = response_bytes_counter.clone();
160         async move {
161             Server::builder()
162                 // no compression enable on the server so responses should not be compressed
163                 .layer(
164                     ServiceBuilder::new()
165                         .layer(MapResponseBodyLayer::new(move |body| {
166                             util::CountBytesBody {
167                                 inner: body,
168                                 counter: response_bytes_counter.clone(),
169                             }
170                         }))
171                         .into_inner(),
172                 )
173                 .add_service(svc)
174                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
175                 .await
176                 .unwrap();
177         }
178     });
179 
180     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
181         .accept_compressed(CompressionEncoding::Gzip)
182         .accept_compressed(CompressionEncoding::Zstd)
183         .accept_compressed(CompressionEncoding::Deflate);
184 
185     let res = client.compress_output_unary(()).await.unwrap();
186 
187     assert!(res.metadata().get("grpc-encoding").is_none());
188 
189     let bytes_sent = response_bytes_counter.load(SeqCst);
190     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
191 }
192 
// Parametrizes `client_disabled` over the compression encodings listed below.
// NOTE(review): presumably expands into one test per `name: encoding` pair —
// see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    client_disabled,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
199 
200 #[allow(dead_code)]
client_disabled(encoding: CompressionEncoding)201 async fn client_disabled(encoding: CompressionEncoding) {
202     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
203 
204     #[derive(Clone, Copy)]
205     struct AssertCorrectAcceptEncoding<S>(S);
206 
207     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
208     where
209         S: Service<http::Request<B>>,
210     {
211         type Response = S::Response;
212         type Error = S::Error;
213         type Future = S::Future;
214 
215         fn poll_ready(
216             &mut self,
217             cx: &mut std::task::Context<'_>,
218         ) -> std::task::Poll<Result<(), Self::Error>> {
219             self.0.poll_ready(cx)
220         }
221 
222         fn call(&mut self, req: http::Request<B>) -> Self::Future {
223             assert!(req.headers().get("grpc-accept-encoding").is_none());
224             self.0.call(req)
225         }
226     }
227 
228     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
229 
230     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
231 
232     tokio::spawn({
233         let response_bytes_counter = response_bytes_counter.clone();
234         async move {
235             Server::builder()
236                 .layer(
237                     ServiceBuilder::new()
238                         .layer(layer_fn(AssertCorrectAcceptEncoding))
239                         .layer(MapResponseBodyLayer::new(move |body| {
240                             util::CountBytesBody {
241                                 inner: body,
242                                 counter: response_bytes_counter.clone(),
243                             }
244                         }))
245                         .into_inner(),
246                 )
247                 .add_service(svc)
248                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
249                 .await
250                 .unwrap();
251         }
252     });
253 
254     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
255 
256     let res = client.compress_output_unary(()).await.unwrap();
257 
258     assert!(res.metadata().get("grpc-encoding").is_none());
259 
260     let bytes_sent = response_bytes_counter.load(SeqCst);
261     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
262 }
263 
// Parametrizes `server_replying_with_unsupported_encoding` over the compression
// encodings listed below. NOTE(review): presumably expands into one test per
// `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    server_replying_with_unsupported_encoding,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
270 
271 #[allow(dead_code)]
server_replying_with_unsupported_encoding(encoding: CompressionEncoding)272 async fn server_replying_with_unsupported_encoding(encoding: CompressionEncoding) {
273     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
274 
275     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
276 
277     fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> {
278         response
279             .headers_mut()
280             .insert("grpc-encoding", "br".parse().unwrap());
281         response
282     }
283 
284     tokio::spawn(async move {
285         Server::builder()
286             .layer(
287                 ServiceBuilder::new()
288                     .map_response(add_weird_content_encoding)
289                     .into_inner(),
290             )
291             .add_service(svc)
292             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
293             .await
294             .unwrap();
295     });
296 
297     let mut client =
298         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
299     let status: Status = client.compress_output_unary(()).await.unwrap_err();
300 
301     assert_eq!(status.code(), tonic::Code::Unimplemented);
302     assert_eq!(
303         status.message(),
304         "Content is compressed with `br` which isn't supported"
305     );
306 }
307 
// Parametrizes `disabling_compression_on_single_response` over the compression
// encodings listed below. NOTE(review): presumably expands into one test per
// `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    disabling_compression_on_single_response,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
314 
315 #[allow(dead_code)]
disabling_compression_on_single_response(encoding: CompressionEncoding)316 async fn disabling_compression_on_single_response(encoding: CompressionEncoding) {
317     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
318 
319     let svc = test_server::TestServer::new(Svc {
320         disable_compressing_on_response: true,
321     })
322     .send_compressed(encoding);
323 
324     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
325 
326     tokio::spawn({
327         let response_bytes_counter = response_bytes_counter.clone();
328         async move {
329             Server::builder()
330                 .layer(
331                     ServiceBuilder::new()
332                         .layer(MapResponseBodyLayer::new(move |body| {
333                             util::CountBytesBody {
334                                 inner: body,
335                                 counter: response_bytes_counter.clone(),
336                             }
337                         }))
338                         .into_inner(),
339                 )
340                 .add_service(svc)
341                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
342                 .await
343                 .unwrap();
344         }
345     });
346 
347     let mut client =
348         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
349 
350     let res = client.compress_output_unary(()).await.unwrap();
351 
352     let expected = match encoding {
353         CompressionEncoding::Gzip => "gzip",
354         CompressionEncoding::Zstd => "zstd",
355         CompressionEncoding::Deflate => "deflate",
356         _ => panic!("unexpected encoding {:?}", encoding),
357     };
358     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
359 
360     let bytes_sent = response_bytes_counter.load(SeqCst);
361     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
362 }
363 
// Parametrizes `disabling_compression_on_response_but_keeping_compression_on_stream`
// over the compression encodings listed below. NOTE(review): presumably expands
// into one test per `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    disabling_compression_on_response_but_keeping_compression_on_stream,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
370 
371 #[allow(dead_code)]
disabling_compression_on_response_but_keeping_compression_on_stream( encoding: CompressionEncoding, )372 async fn disabling_compression_on_response_but_keeping_compression_on_stream(
373     encoding: CompressionEncoding,
374 ) {
375     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
376 
377     let svc = test_server::TestServer::new(Svc {
378         disable_compressing_on_response: true,
379     })
380     .send_compressed(encoding);
381 
382     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
383 
384     tokio::spawn({
385         let response_bytes_counter = response_bytes_counter.clone();
386         async move {
387             Server::builder()
388                 .layer(
389                     ServiceBuilder::new()
390                         .layer(MapResponseBodyLayer::new(move |body| {
391                             util::CountBytesBody {
392                                 inner: body,
393                                 counter: response_bytes_counter.clone(),
394                             }
395                         }))
396                         .into_inner(),
397                 )
398                 .add_service(svc)
399                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
400                 .await
401                 .unwrap();
402         }
403     });
404 
405     let mut client =
406         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
407 
408     let res = client.compress_output_server_stream(()).await.unwrap();
409 
410     let expected = match encoding {
411         CompressionEncoding::Gzip => "gzip",
412         CompressionEncoding::Zstd => "zstd",
413         CompressionEncoding::Deflate => "deflate",
414         _ => panic!("unexpected encoding {:?}", encoding),
415     };
416     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
417 
418     let mut stream: Streaming<SomeData> = res.into_inner();
419 
420     stream
421         .next()
422         .await
423         .expect("stream empty")
424         .expect("item was error");
425     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
426 
427     stream
428         .next()
429         .await
430         .expect("stream empty")
431         .expect("item was error");
432     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
433 }
434 
// Parametrizes `disabling_compression_on_response_from_client_stream` over the
// compression encodings listed below. NOTE(review): presumably expands into one
// test per `name: encoding` pair — see `util::parametrized_tests!` to confirm.
util::parametrized_tests! {
    disabling_compression_on_response_from_client_stream,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
441 
442 #[allow(dead_code)]
disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding)443 async fn disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding) {
444     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
445 
446     let svc = test_server::TestServer::new(Svc {
447         disable_compressing_on_response: true,
448     })
449     .send_compressed(encoding);
450 
451     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
452 
453     tokio::spawn({
454         let response_bytes_counter = response_bytes_counter.clone();
455         async move {
456             Server::builder()
457                 .layer(
458                     ServiceBuilder::new()
459                         .layer(MapResponseBodyLayer::new(move |body| {
460                             util::CountBytesBody {
461                                 inner: body,
462                                 counter: response_bytes_counter.clone(),
463                             }
464                         }))
465                         .into_inner(),
466                 )
467                 .add_service(svc)
468                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
469                 .await
470                 .unwrap();
471         }
472     });
473 
474     let mut client =
475         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
476 
477     let req = Request::new(Box::pin(tokio_stream::empty()));
478 
479     let res = client.compress_output_client_stream(req).await.unwrap();
480 
481     let expected = match encoding {
482         CompressionEncoding::Gzip => "gzip",
483         CompressionEncoding::Zstd => "zstd",
484         CompressionEncoding::Deflate => "deflate",
485         _ => panic!("unexpected encoding {:?}", encoding),
486     };
487     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
488     let bytes_sent = response_bytes_counter.load(SeqCst);
489     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
490 }
491