1 use super::*;
2 use tonic::codec::CompressionEncoding;
3 
4 #[tokio::test(flavor = "multi_thread")]
5 async fn client_enabled_server_enabled() {
6     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
7 
8     #[derive(Clone, Copy)]
9     struct AssertCorrectAcceptEncoding<S>(S);
10 
11     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
12     where
13         S: Service<http::Request<B>>,
14     {
15         type Response = S::Response;
16         type Error = S::Error;
17         type Future = S::Future;
18 
19         fn poll_ready(
20             &mut self,
21             cx: &mut std::task::Context<'_>,
22         ) -> std::task::Poll<Result<(), Self::Error>> {
23             self.0.poll_ready(cx)
24         }
25 
26         fn call(&mut self, req: http::Request<B>) -> Self::Future {
27             assert_eq!(
28                 req.headers().get("grpc-accept-encoding").unwrap(),
29                 "gzip,identity"
30             );
31             self.0.call(req)
32         }
33     }
34 
35     let svc =
36         test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
37 
38     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
39 
40     tokio::spawn({
41         let response_bytes_counter = response_bytes_counter.clone();
42         async move {
43             Server::builder()
44                 .layer(
45                     ServiceBuilder::new()
46                         .layer(layer_fn(AssertCorrectAcceptEncoding))
47                         .layer(MapResponseBodyLayer::new(move |body| {
48                             util::CountBytesBody {
49                                 inner: body,
50                                 counter: response_bytes_counter.clone(),
51                             }
52                         }))
53                         .into_inner(),
54                 )
55                 .add_service(svc)
56                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
57                 .await
58                 .unwrap();
59         }
60     });
61 
62     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
63         .accept_compressed(CompressionEncoding::Gzip);
64 
65     for _ in 0..3 {
66         let res = client.compress_output_unary(()).await.unwrap();
67         assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
68         let bytes_sent = response_bytes_counter.load(SeqCst);
69         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
70     }
71 }
72 
73 #[tokio::test(flavor = "multi_thread")]
74 async fn client_enabled_server_disabled() {
75     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
76 
77     let svc = test_server::TestServer::new(Svc::default());
78 
79     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
80 
81     tokio::spawn({
82         let response_bytes_counter = response_bytes_counter.clone();
83         async move {
84             Server::builder()
85                 // no compression enable on the server so responses should not be compressed
86                 .layer(
87                     ServiceBuilder::new()
88                         .layer(MapResponseBodyLayer::new(move |body| {
89                             util::CountBytesBody {
90                                 inner: body,
91                                 counter: response_bytes_counter.clone(),
92                             }
93                         }))
94                         .into_inner(),
95                 )
96                 .add_service(svc)
97                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
98                 .await
99                 .unwrap();
100         }
101     });
102 
103     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
104         .accept_compressed(CompressionEncoding::Gzip);
105 
106     let res = client.compress_output_unary(()).await.unwrap();
107 
108     assert!(res.metadata().get("grpc-encoding").is_none());
109 
110     let bytes_sent = response_bytes_counter.load(SeqCst);
111     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
112 }
113 
114 #[tokio::test(flavor = "multi_thread")]
115 async fn client_disabled() {
116     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
117 
118     #[derive(Clone, Copy)]
119     struct AssertCorrectAcceptEncoding<S>(S);
120 
121     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
122     where
123         S: Service<http::Request<B>>,
124     {
125         type Response = S::Response;
126         type Error = S::Error;
127         type Future = S::Future;
128 
129         fn poll_ready(
130             &mut self,
131             cx: &mut std::task::Context<'_>,
132         ) -> std::task::Poll<Result<(), Self::Error>> {
133             self.0.poll_ready(cx)
134         }
135 
136         fn call(&mut self, req: http::Request<B>) -> Self::Future {
137             assert!(req.headers().get("grpc-accept-encoding").is_none());
138             self.0.call(req)
139         }
140     }
141 
142     let svc =
143         test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
144 
145     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
146 
147     tokio::spawn({
148         let response_bytes_counter = response_bytes_counter.clone();
149         async move {
150             Server::builder()
151                 .layer(
152                     ServiceBuilder::new()
153                         .layer(layer_fn(AssertCorrectAcceptEncoding))
154                         .layer(MapResponseBodyLayer::new(move |body| {
155                             util::CountBytesBody {
156                                 inner: body,
157                                 counter: response_bytes_counter.clone(),
158                             }
159                         }))
160                         .into_inner(),
161                 )
162                 .add_service(svc)
163                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
164                 .await
165                 .unwrap();
166         }
167     });
168 
169     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
170 
171     let res = client.compress_output_unary(()).await.unwrap();
172 
173     assert!(res.metadata().get("grpc-encoding").is_none());
174 
175     let bytes_sent = response_bytes_counter.load(SeqCst);
176     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
177 }
178 
179 #[tokio::test(flavor = "multi_thread")]
180 async fn server_replying_with_unsupported_encoding() {
181     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
182 
183     let svc =
184         test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
185 
186     fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> {
187         response
188             .headers_mut()
189             .insert("grpc-encoding", "br".parse().unwrap());
190         response
191     }
192 
193     tokio::spawn(async move {
194         Server::builder()
195             .layer(
196                 ServiceBuilder::new()
197                     .map_response(add_weird_content_encoding)
198                     .into_inner(),
199             )
200             .add_service(svc)
201             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
202             .await
203             .unwrap();
204     });
205 
206     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
207         .accept_compressed(CompressionEncoding::Gzip);
208     let status: Status = client.compress_output_unary(()).await.unwrap_err();
209 
210     assert_eq!(status.code(), tonic::Code::Unimplemented);
211     assert_eq!(
212         status.message(),
213         "Content is compressed with `br` which isn't supported"
214     );
215 }
216 
217 #[tokio::test(flavor = "multi_thread")]
218 async fn disabling_compression_on_single_response() {
219     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
220 
221     let svc = test_server::TestServer::new(Svc {
222         disable_compressing_on_response: true,
223     })
224     .send_compressed(CompressionEncoding::Gzip);
225 
226     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
227 
228     tokio::spawn({
229         let response_bytes_counter = response_bytes_counter.clone();
230         async move {
231             Server::builder()
232                 .layer(
233                     ServiceBuilder::new()
234                         .layer(MapResponseBodyLayer::new(move |body| {
235                             util::CountBytesBody {
236                                 inner: body,
237                                 counter: response_bytes_counter.clone(),
238                             }
239                         }))
240                         .into_inner(),
241                 )
242                 .add_service(svc)
243                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
244                 .await
245                 .unwrap();
246         }
247     });
248 
249     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
250         .accept_compressed(CompressionEncoding::Gzip);
251 
252     let res = client.compress_output_unary(()).await.unwrap();
253     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
254     let bytes_sent = response_bytes_counter.load(SeqCst);
255     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
256 }
257 
258 #[tokio::test(flavor = "multi_thread")]
259 async fn disabling_compression_on_response_but_keeping_compression_on_stream() {
260     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
261 
262     let svc = test_server::TestServer::new(Svc {
263         disable_compressing_on_response: true,
264     })
265     .send_compressed(CompressionEncoding::Gzip);
266 
267     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
268 
269     tokio::spawn({
270         let response_bytes_counter = response_bytes_counter.clone();
271         async move {
272             Server::builder()
273                 .layer(
274                     ServiceBuilder::new()
275                         .layer(MapResponseBodyLayer::new(move |body| {
276                             util::CountBytesBody {
277                                 inner: body,
278                                 counter: response_bytes_counter.clone(),
279                             }
280                         }))
281                         .into_inner(),
282                 )
283                 .add_service(svc)
284                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
285                 .await
286                 .unwrap();
287         }
288     });
289 
290     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
291         .accept_compressed(CompressionEncoding::Gzip);
292 
293     let res = client.compress_output_server_stream(()).await.unwrap();
294 
295     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
296 
297     let mut stream: Streaming<SomeData> = res.into_inner();
298 
299     stream
300         .next()
301         .await
302         .expect("stream empty")
303         .expect("item was error");
304     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
305 
306     stream
307         .next()
308         .await
309         .expect("stream empty")
310         .expect("item was error");
311     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
312 }
313 
314 #[tokio::test(flavor = "multi_thread")]
315 async fn disabling_compression_on_response_from_client_stream() {
316     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
317 
318     let svc = test_server::TestServer::new(Svc {
319         disable_compressing_on_response: true,
320     })
321     .send_compressed(CompressionEncoding::Gzip);
322 
323     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
324 
325     tokio::spawn({
326         let response_bytes_counter = response_bytes_counter.clone();
327         async move {
328             Server::builder()
329                 .layer(
330                     ServiceBuilder::new()
331                         .layer(MapResponseBodyLayer::new(move |body| {
332                             util::CountBytesBody {
333                                 inner: body,
334                                 counter: response_bytes_counter.clone(),
335                             }
336                         }))
337                         .into_inner(),
338                 )
339                 .add_service(svc)
340                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
341                 .await
342                 .unwrap();
343         }
344     });
345 
346     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
347         .accept_compressed(CompressionEncoding::Gzip);
348 
349     let req = Request::new(Box::pin(tokio_stream::empty()));
350 
351     let res = client.compress_output_client_stream(req).await.unwrap();
352     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
353     let bytes_sent = response_bytes_counter.load(SeqCst);
354     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
355 }
356