1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4 
// Instantiate `client_enabled_server_enabled` for each compression encoding
// listed below. `parametrized_tests!` lives in `util` (not shown here);
// presumably it emits one test per `name: encoding` entry.
util::parametrized_tests! {
    client_enabled_server_enabled,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
11 
12 #[allow(dead_code)]
13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15 
16     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
17 
18     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
19 
20     #[derive(Clone)]
21     pub struct AssertRightEncoding {
22         encoding: CompressionEncoding,
23     }
24 
25     #[allow(dead_code)]
26     impl AssertRightEncoding {
27         pub fn new(encoding: CompressionEncoding) -> Self {
28             Self { encoding }
29         }
30 
31         pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
32             let expected = match self.encoding {
33                 CompressionEncoding::Gzip => "gzip",
34                 CompressionEncoding::Zstd => "zstd",
35                 CompressionEncoding::Deflate => "deflate",
36                 _ => panic!("unexpected encoding {:?}", self.encoding),
37             };
38             assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
39 
40             req
41         }
42     }
43 
44     tokio::spawn({
45         let request_bytes_counter = request_bytes_counter.clone();
46         async move {
47             Server::builder()
48                 .layer(
49                     ServiceBuilder::new()
50                         .layer(
51                             ServiceBuilder::new()
52                                 .map_request(move |req| {
53                                     AssertRightEncoding::new(encoding).clone().call(req)
54                                 })
55                                 .layer(measure_request_body_size_layer(request_bytes_counter))
56                                 .into_inner(),
57                         )
58                         .into_inner(),
59                 )
60                 .add_service(svc)
61                 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
62                 .await
63                 .unwrap();
64         }
65     });
66 
67     let mut client =
68         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
69 
70     for _ in 0..3 {
71         client
72             .compress_input_unary(SomeData {
73                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
74             })
75             .await
76             .unwrap();
77         let bytes_sent = request_bytes_counter.load(SeqCst);
78         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
79     }
80 }
81 
// Instantiate `client_enabled_server_enabled_multi_encoding` for each
// compression encoding listed below. `parametrized_tests!` lives in `util`
// (not shown here); presumably it emits one test per `name: encoding` entry.
util::parametrized_tests! {
    client_enabled_server_enabled_multi_encoding,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
88 
89 #[allow(dead_code)]
90 async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) {
91     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
92 
93     let svc = test_server::TestServer::new(Svc::default())
94         .accept_compressed(CompressionEncoding::Gzip)
95         .accept_compressed(CompressionEncoding::Zstd)
96         .accept_compressed(CompressionEncoding::Deflate);
97 
98     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
99 
100     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
101         let supported_encodings = ["gzip", "zstd", "deflate"];
102         let req_encoding = req.headers().get("grpc-encoding").unwrap();
103         assert!(supported_encodings.iter().any(|e| e == req_encoding));
104 
105         req
106     }
107 
108     tokio::spawn({
109         let request_bytes_counter = request_bytes_counter.clone();
110         async move {
111             Server::builder()
112                 .layer(
113                     ServiceBuilder::new()
114                         .layer(
115                             ServiceBuilder::new()
116                                 .map_request(assert_right_encoding)
117                                 .layer(measure_request_body_size_layer(request_bytes_counter))
118                                 .into_inner(),
119                         )
120                         .into_inner(),
121                 )
122                 .add_service(svc)
123                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
124                 .await
125                 .unwrap();
126         }
127     });
128 
129     let mut client =
130         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
131 
132     for _ in 0..3 {
133         client
134             .compress_input_unary(SomeData {
135                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
136             })
137             .await
138             .unwrap();
139         let bytes_sent = request_bytes_counter.load(SeqCst);
140         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
141     }
142 }
143 
144 parametrized_tests! {
145     client_enabled_server_disabled,
146     zstd: CompressionEncoding::Zstd,
147     gzip: CompressionEncoding::Gzip,
148     deflate: CompressionEncoding::Deflate,
149 }
150 
151 #[allow(dead_code)]
152 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
153     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
154 
155     let svc = test_server::TestServer::new(Svc::default());
156 
157     tokio::spawn(async move {
158         Server::builder()
159             .add_service(svc)
160             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
161             .await
162             .unwrap();
163     });
164 
165     let mut client =
166         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
167 
168     let status = client
169         .compress_input_unary(SomeData {
170             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
171         })
172         .await
173         .unwrap_err();
174 
175     assert_eq!(status.code(), tonic::Code::Unimplemented);
176     let expected = match encoding {
177         CompressionEncoding::Gzip => "gzip",
178         CompressionEncoding::Zstd => "zstd",
179         CompressionEncoding::Deflate => "deflate",
180         _ => panic!("unexpected encoding {:?}", encoding),
181     };
182     assert_eq!(
183         status.message(),
184         format!(
185             "Content is compressed with `{}` which isn't supported",
186             expected
187         )
188     );
189 
190     assert_eq!(
191         status.metadata().get("grpc-accept-encoding").unwrap(),
192         "identity"
193     );
194 }
195 parametrized_tests! {
196     client_mark_compressed_without_header_server_enabled,
197     zstd: CompressionEncoding::Zstd,
198     gzip: CompressionEncoding::Gzip,
199     deflate: CompressionEncoding::Deflate,
200 }
201 
202 #[allow(dead_code)]
203 async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) {
204     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
205 
206     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
207 
208     tokio::spawn({
209         async move {
210             Server::builder()
211                 .add_service(svc)
212                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
213                 .await
214                 .unwrap();
215         }
216     });
217 
218     let mut client = test_client::TestClient::with_interceptor(
219         mock_io_channel(client).await,
220         move |mut req: Request<()>| {
221             req.metadata_mut().remove("grpc-encoding");
222             Ok(req)
223         },
224     )
225     .send_compressed(CompressionEncoding::Gzip);
226 
227     let status = client
228         .compress_input_unary(SomeData {
229             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
230         })
231         .await
232         .unwrap_err();
233 
234     assert_eq!(status.code(), tonic::Code::Internal);
235     assert_eq!(
236         status.message(),
237         "protocol error: received message with compressed-flag but no grpc-encoding was specified"
238     );
239 }
240