1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4 
5 util::parametrized_tests! {
6     client_enabled_server_enabled,
7     zstd: CompressionEncoding::Zstd,
8     gzip: CompressionEncoding::Gzip,
9     deflate: CompressionEncoding::Deflate,
10 }
11 
12 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15 
16     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
17 
18     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
19 
20     #[derive(Clone)]
21     pub struct AssertRightEncoding {
22         encoding: CompressionEncoding,
23     }
24 
25     #[allow(dead_code)]
26     impl AssertRightEncoding {
27         pub fn new(encoding: CompressionEncoding) -> Self {
28             Self { encoding }
29         }
30 
31         pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
32             let expected = match self.encoding {
33                 CompressionEncoding::Gzip => "gzip",
34                 CompressionEncoding::Zstd => "zstd",
35                 CompressionEncoding::Deflate => "deflate",
36                 _ => panic!("unexpected encoding {:?}", self.encoding),
37             };
38             assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
39 
40             req
41         }
42     }
43 
44     tokio::spawn({
45         let request_bytes_counter = request_bytes_counter.clone();
46         async move {
47             Server::builder()
48                 .layer(
49                     ServiceBuilder::new()
50                         .map_request(move |req| {
51                             AssertRightEncoding::new(encoding).clone().call(req)
52                         })
53                         .layer(measure_request_body_size_layer(
54                             request_bytes_counter.clone(),
55                         ))
56                         .into_inner(),
57                 )
58                 .add_service(svc)
59                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
60                 .await
61                 .unwrap();
62         }
63     });
64 
65     let mut client =
66         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
67 
68     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
69     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
70     let req = Request::new(Box::pin(stream));
71 
72     client.compress_input_client_stream(req).await.unwrap();
73 
74     let bytes_sent = request_bytes_counter.load(SeqCst);
75     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
76 }
77 
78 util::parametrized_tests! {
79     client_disabled_server_enabled,
80     zstd: CompressionEncoding::Zstd,
81     gzip: CompressionEncoding::Gzip,
82     deflate: CompressionEncoding::Deflate,
83 }
84 
85 #[allow(dead_code)]
client_disabled_server_enabled(encoding: CompressionEncoding)86 async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
87     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
88 
89     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
90 
91     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
92 
93     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
94         assert!(req.headers().get("grpc-encoding").is_none());
95         req
96     }
97 
98     tokio::spawn({
99         let request_bytes_counter = request_bytes_counter.clone();
100         async move {
101             Server::builder()
102                 .layer(
103                     ServiceBuilder::new()
104                         .map_request(assert_right_encoding)
105                         .layer(measure_request_body_size_layer(
106                             request_bytes_counter.clone(),
107                         ))
108                         .into_inner(),
109                 )
110                 .add_service(svc)
111                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
112                 .await
113                 .unwrap();
114         }
115     });
116 
117     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
118 
119     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
120     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
121     let req = Request::new(Box::pin(stream));
122 
123     client.compress_input_client_stream(req).await.unwrap();
124 
125     let bytes_sent = request_bytes_counter.load(SeqCst);
126     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
127 }
128 
129 util::parametrized_tests! {
130     client_enabled_server_disabled,
131     zstd: CompressionEncoding::Zstd,
132     gzip: CompressionEncoding::Gzip,
133     deflate: CompressionEncoding::Deflate,
134 }
135 
136 #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)137 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
138     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
139 
140     let svc = test_server::TestServer::new(Svc::default());
141 
142     tokio::spawn(async move {
143         Server::builder()
144             .add_service(svc)
145             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
146             .await
147             .unwrap();
148     });
149 
150     let mut client =
151         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
152 
153     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
154     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
155     let req = Request::new(Box::pin(stream));
156 
157     let status = client.compress_input_client_stream(req).await.unwrap_err();
158 
159     assert_eq!(status.code(), tonic::Code::Unimplemented);
160     let expected = match encoding {
161         CompressionEncoding::Gzip => "gzip",
162         CompressionEncoding::Zstd => "zstd",
163         CompressionEncoding::Deflate => "deflate",
164         _ => panic!("unexpected encoding {:?}", encoding),
165     };
166     assert_eq!(
167         status.message(),
168         format!(
169             "Content is compressed with `{}` which isn't supported",
170             expected
171         )
172     );
173 }
174 
175 util::parametrized_tests! {
176     compressing_response_from_client_stream,
177     zstd: CompressionEncoding::Zstd,
178     gzip: CompressionEncoding::Gzip,
179     deflate: CompressionEncoding::Deflate,
180 }
181 
182 #[allow(dead_code)]
compressing_response_from_client_stream(encoding: CompressionEncoding)183 async fn compressing_response_from_client_stream(encoding: CompressionEncoding) {
184     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
185 
186     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
187 
188     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
189 
190     tokio::spawn({
191         let response_bytes_counter = response_bytes_counter.clone();
192         async move {
193             Server::builder()
194                 .layer(
195                     ServiceBuilder::new()
196                         .layer(MapResponseBodyLayer::new(move |body| {
197                             util::CountBytesBody {
198                                 inner: body,
199                                 counter: response_bytes_counter.clone(),
200                             }
201                         }))
202                         .into_inner(),
203                 )
204                 .add_service(svc)
205                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
206                 .await
207                 .unwrap();
208         }
209     });
210 
211     let mut client =
212         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
213 
214     let req = Request::new(Box::pin(tokio_stream::empty()));
215 
216     let res = client.compress_output_client_stream(req).await.unwrap();
217     let expected = match encoding {
218         CompressionEncoding::Gzip => "gzip",
219         CompressionEncoding::Zstd => "zstd",
220         CompressionEncoding::Deflate => "deflate",
221         _ => panic!("unexpected encoding {:?}", encoding),
222     };
223     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
224     let bytes_sent = response_bytes_counter.load(SeqCst);
225     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
226 }
227