1 use super::*;
2 
3 #[tokio::test(flavor = "multi_thread")]
4 async fn client_enabled_server_enabled() {
5     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
6 
7     #[derive(Clone, Copy)]
8     struct AssertCorrectAcceptEncoding<S>(S);
9 
10     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
11     where
12         S: Service<http::Request<B>>,
13     {
14         type Response = S::Response;
15         type Error = S::Error;
16         type Future = S::Future;
17 
18         fn poll_ready(
19             &mut self,
20             cx: &mut std::task::Context<'_>,
21         ) -> std::task::Poll<Result<(), Self::Error>> {
22             self.0.poll_ready(cx)
23         }
24 
25         fn call(&mut self, req: http::Request<B>) -> Self::Future {
26             assert_eq!(
27                 req.headers().get("grpc-accept-encoding").unwrap(),
28                 "gzip,identity"
29             );
30             self.0.call(req)
31         }
32     }
33 
34     let svc = test_server::TestServer::new(Svc::default()).send_gzip();
35 
36     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
37 
38     tokio::spawn({
39         let response_bytes_counter = response_bytes_counter.clone();
40         async move {
41             Server::builder()
42                 .layer(
43                     ServiceBuilder::new()
44                         .layer(layer_fn(AssertCorrectAcceptEncoding))
45                         .layer(MapResponseBodyLayer::new(move |body| {
46                             util::CountBytesBody {
47                                 inner: body,
48                                 counter: response_bytes_counter.clone(),
49                             }
50                         }))
51                         .into_inner(),
52                 )
53                 .add_service(svc)
54                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
55                     MockStream(server),
56                 )]))
57                 .await
58                 .unwrap();
59         }
60     });
61 
62     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
63 
64     for _ in 0..3 {
65         let res = client.compress_output_unary(()).await.unwrap();
66         assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
67         let bytes_sent = response_bytes_counter.load(SeqCst);
68         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
69     }
70 }
71 
72 #[tokio::test(flavor = "multi_thread")]
73 async fn client_enabled_server_disabled() {
74     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
75 
76     let svc = test_server::TestServer::new(Svc::default());
77 
78     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
79 
80     tokio::spawn({
81         let response_bytes_counter = response_bytes_counter.clone();
82         async move {
83             Server::builder()
84                 // no compression enable on the server so responses should not be compressed
85                 .layer(
86                     ServiceBuilder::new()
87                         .layer(MapResponseBodyLayer::new(move |body| {
88                             util::CountBytesBody {
89                                 inner: body,
90                                 counter: response_bytes_counter.clone(),
91                             }
92                         }))
93                         .into_inner(),
94                 )
95                 .add_service(svc)
96                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
97                     MockStream(server),
98                 )]))
99                 .await
100                 .unwrap();
101         }
102     });
103 
104     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
105 
106     let res = client.compress_output_unary(()).await.unwrap();
107 
108     assert!(res.metadata().get("grpc-encoding").is_none());
109 
110     let bytes_sent = response_bytes_counter.load(SeqCst);
111     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
112 }
113 
114 #[tokio::test(flavor = "multi_thread")]
115 async fn client_disabled() {
116     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
117 
118     #[derive(Clone, Copy)]
119     struct AssertCorrectAcceptEncoding<S>(S);
120 
121     impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S>
122     where
123         S: Service<http::Request<B>>,
124     {
125         type Response = S::Response;
126         type Error = S::Error;
127         type Future = S::Future;
128 
129         fn poll_ready(
130             &mut self,
131             cx: &mut std::task::Context<'_>,
132         ) -> std::task::Poll<Result<(), Self::Error>> {
133             self.0.poll_ready(cx)
134         }
135 
136         fn call(&mut self, req: http::Request<B>) -> Self::Future {
137             assert!(req.headers().get("grpc-accept-encoding").is_none());
138             self.0.call(req)
139         }
140     }
141 
142     let svc = test_server::TestServer::new(Svc::default()).send_gzip();
143 
144     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
145 
146     tokio::spawn({
147         let response_bytes_counter = response_bytes_counter.clone();
148         async move {
149             Server::builder()
150                 .layer(
151                     ServiceBuilder::new()
152                         .layer(layer_fn(AssertCorrectAcceptEncoding))
153                         .layer(MapResponseBodyLayer::new(move |body| {
154                             util::CountBytesBody {
155                                 inner: body,
156                                 counter: response_bytes_counter.clone(),
157                             }
158                         }))
159                         .into_inner(),
160                 )
161                 .add_service(svc)
162                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
163                     MockStream(server),
164                 )]))
165                 .await
166                 .unwrap();
167         }
168     });
169 
170     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
171 
172     let res = client.compress_output_unary(()).await.unwrap();
173 
174     assert!(res.metadata().get("grpc-encoding").is_none());
175 
176     let bytes_sent = response_bytes_counter.load(SeqCst);
177     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
178 }
179 
180 #[tokio::test(flavor = "multi_thread")]
181 async fn server_replying_with_unsupported_encoding() {
182     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
183 
184     let svc = test_server::TestServer::new(Svc::default()).send_gzip();
185 
186     fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> {
187         response
188             .headers_mut()
189             .insert("grpc-encoding", "br".parse().unwrap());
190         response
191     }
192 
193     tokio::spawn(async move {
194         Server::builder()
195             .layer(
196                 ServiceBuilder::new()
197                     .map_response(add_weird_content_encoding)
198                     .into_inner(),
199             )
200             .add_service(svc)
201             .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
202                 MockStream(server),
203             )]))
204             .await
205             .unwrap();
206     });
207 
208     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
209     let status: Status = client.compress_output_unary(()).await.unwrap_err();
210 
211     assert_eq!(status.code(), tonic::Code::Unimplemented);
212     assert_eq!(
213         status.message(),
214         "Content is compressed with `br` which isn't supported"
215     );
216 }
217 
218 #[tokio::test(flavor = "multi_thread")]
219 async fn disabling_compression_on_single_response() {
220     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
221 
222     let svc = test_server::TestServer::new(Svc {
223         disable_compressing_on_response: true,
224     })
225     .send_gzip();
226 
227     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
228 
229     tokio::spawn({
230         let response_bytes_counter = response_bytes_counter.clone();
231         async move {
232             Server::builder()
233                 .layer(
234                     ServiceBuilder::new()
235                         .layer(MapResponseBodyLayer::new(move |body| {
236                             util::CountBytesBody {
237                                 inner: body,
238                                 counter: response_bytes_counter.clone(),
239                             }
240                         }))
241                         .into_inner(),
242                 )
243                 .add_service(svc)
244                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
245                     MockStream(server),
246                 )]))
247                 .await
248                 .unwrap();
249         }
250     });
251 
252     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
253 
254     let res = client.compress_output_unary(()).await.unwrap();
255     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
256     let bytes_sent = response_bytes_counter.load(SeqCst);
257     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
258 }
259 
260 #[tokio::test(flavor = "multi_thread")]
261 async fn disabling_compression_on_response_but_keeping_compression_on_stream() {
262     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
263 
264     let svc = test_server::TestServer::new(Svc {
265         disable_compressing_on_response: true,
266     })
267     .send_gzip();
268 
269     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
270 
271     tokio::spawn({
272         let response_bytes_counter = response_bytes_counter.clone();
273         async move {
274             Server::builder()
275                 .layer(
276                     ServiceBuilder::new()
277                         .layer(MapResponseBodyLayer::new(move |body| {
278                             util::CountBytesBody {
279                                 inner: body,
280                                 counter: response_bytes_counter.clone(),
281                             }
282                         }))
283                         .into_inner(),
284                 )
285                 .add_service(svc)
286                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
287                     MockStream(server),
288                 )]))
289                 .await
290                 .unwrap();
291         }
292     });
293 
294     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
295 
296     let res = client.compress_output_server_stream(()).await.unwrap();
297 
298     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
299 
300     let mut stream: Streaming<SomeData> = res.into_inner();
301 
302     stream
303         .next()
304         .await
305         .expect("stream empty")
306         .expect("item was error");
307     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
308 
309     stream
310         .next()
311         .await
312         .expect("stream empty")
313         .expect("item was error");
314     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
315 }
316 
317 #[tokio::test(flavor = "multi_thread")]
318 async fn disabling_compression_on_response_from_client_stream() {
319     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
320 
321     let svc = test_server::TestServer::new(Svc {
322         disable_compressing_on_response: true,
323     })
324     .send_gzip();
325 
326     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
327 
328     tokio::spawn({
329         let response_bytes_counter = response_bytes_counter.clone();
330         async move {
331             Server::builder()
332                 .layer(
333                     ServiceBuilder::new()
334                         .layer(MapResponseBodyLayer::new(move |body| {
335                             util::CountBytesBody {
336                                 inner: body,
337                                 counter: response_bytes_counter.clone(),
338                             }
339                         }))
340                         .into_inner(),
341                 )
342                 .add_service(svc)
343                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
344                     MockStream(server),
345                 )]))
346                 .await
347                 .unwrap();
348         }
349     });
350 
351     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
352 
353     let stream = futures::stream::iter(vec![]);
354     let req = Request::new(Box::pin(stream));
355 
356     let res = client.compress_output_client_stream(req).await.unwrap();
357     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
358     let bytes_sent = response_bytes_counter.load(SeqCst);
359     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
360 }
361