1 use super::*;
2 use http_body::Body as _;
3 
4 #[tokio::test(flavor = "multi_thread")]
5 async fn client_enabled_server_enabled() {
6     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
7 
8     let svc = test_server::TestServer::new(Svc::default()).accept_gzip();
9 
10     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
11 
12     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
13         assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
14         req
15     }
16 
17     tokio::spawn({
18         let request_bytes_counter = request_bytes_counter.clone();
19         async move {
20             Server::builder()
21                 .layer(
22                     ServiceBuilder::new()
23                         .map_request(assert_right_encoding)
24                         .layer(measure_request_body_size_layer(
25                             request_bytes_counter.clone(),
26                         ))
27                         .into_inner(),
28                 )
29                 .add_service(svc)
30                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
31                     MockStream(server),
32                 )]))
33                 .await
34                 .unwrap();
35         }
36     });
37 
38     let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip();
39 
40     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
41     let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
42     let req = Request::new(Box::pin(stream));
43 
44     client.compress_input_client_stream(req).await.unwrap();
45 
46     let bytes_sent = request_bytes_counter.load(SeqCst);
47     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
48 }
49 
50 #[tokio::test(flavor = "multi_thread")]
51 async fn client_disabled_server_enabled() {
52     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
53 
54     let svc = test_server::TestServer::new(Svc::default()).accept_gzip();
55 
56     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
57 
58     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
59         assert!(req.headers().get("grpc-encoding").is_none());
60         req
61     }
62 
63     tokio::spawn({
64         let request_bytes_counter = request_bytes_counter.clone();
65         async move {
66             Server::builder()
67                 .layer(
68                     ServiceBuilder::new()
69                         .map_request(assert_right_encoding)
70                         .layer(measure_request_body_size_layer(
71                             request_bytes_counter.clone(),
72                         ))
73                         .into_inner(),
74                 )
75                 .add_service(svc)
76                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
77                     MockStream(server),
78                 )]))
79                 .await
80                 .unwrap();
81         }
82     });
83 
84     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
85 
86     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
87     let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
88     let req = Request::new(Box::pin(stream));
89 
90     client.compress_input_client_stream(req).await.unwrap();
91 
92     let bytes_sent = request_bytes_counter.load(SeqCst);
93     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
94 }
95 
96 #[tokio::test(flavor = "multi_thread")]
97 async fn client_enabled_server_disabled() {
98     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
99 
100     let svc = test_server::TestServer::new(Svc::default());
101 
102     tokio::spawn(async move {
103         Server::builder()
104             .add_service(svc)
105             .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
106                 MockStream(server),
107             )]))
108             .await
109             .unwrap();
110     });
111 
112     let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip();
113 
114     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
115     let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
116     let req = Request::new(Box::pin(stream));
117 
118     let status = client.compress_input_client_stream(req).await.unwrap_err();
119 
120     assert_eq!(status.code(), tonic::Code::Unimplemented);
121     assert_eq!(
122         status.message(),
123         "Content is compressed with `gzip` which isn't supported"
124     );
125 }
126 
127 #[tokio::test(flavor = "multi_thread")]
128 async fn compressing_response_from_client_stream() {
129     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
130 
131     let svc = test_server::TestServer::new(Svc::default()).send_gzip();
132 
133     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
134 
135     tokio::spawn({
136         let response_bytes_counter = response_bytes_counter.clone();
137         async move {
138             Server::builder()
139                 .layer(
140                     ServiceBuilder::new()
141                         .layer(MapResponseBodyLayer::new(move |body| {
142                             util::CountBytesBody {
143                                 inner: body,
144                                 counter: response_bytes_counter.clone(),
145                             }
146                         }))
147                         .into_inner(),
148                 )
149                 .add_service(svc)
150                 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(
151                     MockStream(server),
152                 )]))
153                 .await
154                 .unwrap();
155         }
156     });
157 
158     let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip();
159 
160     let stream = futures::stream::iter(vec![]);
161     let req = Request::new(Box::pin(stream));
162 
163     let res = client.compress_output_client_stream(req).await.unwrap();
164     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
165     let bytes_sent = response_bytes_counter.load(SeqCst);
166     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
167 }
168