1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4
5 util::parametrized_tests! {
6 client_enabled_server_enabled,
7 zstd: CompressionEncoding::Zstd,
8 gzip: CompressionEncoding::Gzip,
9 deflate: CompressionEncoding::Deflate,
10 }
11
12 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15
16 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
17
18 let request_bytes_counter = Arc::new(AtomicUsize::new(0));
19
20 #[derive(Clone)]
21 pub struct AssertRightEncoding {
22 encoding: CompressionEncoding,
23 }
24
25 #[allow(dead_code)]
26 impl AssertRightEncoding {
27 pub fn new(encoding: CompressionEncoding) -> Self {
28 Self { encoding }
29 }
30
31 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
32 let expected = match self.encoding {
33 CompressionEncoding::Gzip => "gzip",
34 CompressionEncoding::Zstd => "zstd",
35 CompressionEncoding::Deflate => "deflate",
36 _ => panic!("unexpected encoding {:?}", self.encoding),
37 };
38 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
39
40 req
41 }
42 }
43
44 tokio::spawn({
45 let request_bytes_counter = request_bytes_counter.clone();
46 async move {
47 Server::builder()
48 .layer(
49 ServiceBuilder::new()
50 .layer(
51 ServiceBuilder::new()
52 .map_request(move |req| {
53 AssertRightEncoding::new(encoding).clone().call(req)
54 })
55 .layer(measure_request_body_size_layer(request_bytes_counter))
56 .into_inner(),
57 )
58 .into_inner(),
59 )
60 .add_service(svc)
61 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
62 .await
63 .unwrap();
64 }
65 });
66
67 let mut client =
68 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
69
70 for _ in 0..3 {
71 client
72 .compress_input_unary(SomeData {
73 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
74 })
75 .await
76 .unwrap();
77 let bytes_sent = request_bytes_counter.load(SeqCst);
78 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
79 }
80 }
81
82 util::parametrized_tests! {
83 client_enabled_server_enabled_multi_encoding,
84 zstd: CompressionEncoding::Zstd,
85 gzip: CompressionEncoding::Gzip,
86 deflate: CompressionEncoding::Deflate,
87 }
88
89 #[allow(dead_code)]
client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding)90 async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) {
91 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
92
93 let svc = test_server::TestServer::new(Svc::default())
94 .accept_compressed(CompressionEncoding::Gzip)
95 .accept_compressed(CompressionEncoding::Zstd)
96 .accept_compressed(CompressionEncoding::Deflate);
97
98 let request_bytes_counter = Arc::new(AtomicUsize::new(0));
99
100 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
101 let supported_encodings = ["gzip", "zstd", "deflate"];
102 let req_encoding = req.headers().get("grpc-encoding").unwrap();
103 assert!(supported_encodings.iter().any(|e| e == req_encoding));
104
105 req
106 }
107
108 tokio::spawn({
109 let request_bytes_counter = request_bytes_counter.clone();
110 async move {
111 Server::builder()
112 .layer(
113 ServiceBuilder::new()
114 .layer(
115 ServiceBuilder::new()
116 .map_request(assert_right_encoding)
117 .layer(measure_request_body_size_layer(request_bytes_counter))
118 .into_inner(),
119 )
120 .into_inner(),
121 )
122 .add_service(svc)
123 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
124 .await
125 .unwrap();
126 }
127 });
128
129 let mut client =
130 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
131
132 for _ in 0..3 {
133 client
134 .compress_input_unary(SomeData {
135 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
136 })
137 .await
138 .unwrap();
139 let bytes_sent = request_bytes_counter.load(SeqCst);
140 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
141 }
142 }
143
144 parametrized_tests! {
145 client_enabled_server_disabled,
146 zstd: CompressionEncoding::Zstd,
147 gzip: CompressionEncoding::Gzip,
148 deflate: CompressionEncoding::Deflate,
149 }
150
151 #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)152 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
153 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
154
155 let svc = test_server::TestServer::new(Svc::default());
156
157 tokio::spawn(async move {
158 Server::builder()
159 .add_service(svc)
160 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
161 .await
162 .unwrap();
163 });
164
165 let mut client =
166 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
167
168 let status = client
169 .compress_input_unary(SomeData {
170 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
171 })
172 .await
173 .unwrap_err();
174
175 assert_eq!(status.code(), tonic::Code::Unimplemented);
176 let expected = match encoding {
177 CompressionEncoding::Gzip => "gzip",
178 CompressionEncoding::Zstd => "zstd",
179 CompressionEncoding::Deflate => "deflate",
180 _ => panic!("unexpected encoding {:?}", encoding),
181 };
182 assert_eq!(
183 status.message(),
184 format!(
185 "Content is compressed with `{}` which isn't supported",
186 expected
187 )
188 );
189
190 assert_eq!(
191 status.metadata().get("grpc-accept-encoding").unwrap(),
192 "identity"
193 );
194 }
195 parametrized_tests! {
196 client_mark_compressed_without_header_server_enabled,
197 zstd: CompressionEncoding::Zstd,
198 gzip: CompressionEncoding::Gzip,
199 deflate: CompressionEncoding::Deflate,
200 }
201
202 #[allow(dead_code)]
client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding)203 async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) {
204 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
205
206 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
207
208 tokio::spawn({
209 async move {
210 Server::builder()
211 .add_service(svc)
212 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
213 .await
214 .unwrap();
215 }
216 });
217
218 let mut client = test_client::TestClient::with_interceptor(
219 mock_io_channel(client).await,
220 move |mut req: Request<()>| {
221 req.metadata_mut().remove("grpc-encoding");
222 Ok(req)
223 },
224 )
225 .send_compressed(CompressionEncoding::Gzip);
226
227 let status = client
228 .compress_input_unary(SomeData {
229 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
230 })
231 .await
232 .unwrap_err();
233
234 assert_eq!(status.code(), tonic::Code::Internal);
235 assert_eq!(
236 status.message(),
237 "protocol error: received message with compressed-flag but no grpc-encoding was specified"
238 );
239 }
240