1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4 
5 util::parametrized_tests! {
6     client_enabled_server_enabled,
7     zstd: CompressionEncoding::Zstd,
8     gzip: CompressionEncoding::Gzip,
9 }
10 
11 #[allow(dead_code)]
12 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
13     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
14 
15     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
16 
17     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
18 
19     #[derive(Clone)]
20     pub struct AssertRightEncoding {
21         encoding: CompressionEncoding,
22     }
23 
24     #[allow(dead_code)]
25     impl AssertRightEncoding {
26         pub fn new(encoding: CompressionEncoding) -> Self {
27             Self { encoding }
28         }
29 
30         pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
31             let expected = match self.encoding {
32                 CompressionEncoding::Gzip => "gzip",
33                 CompressionEncoding::Zstd => "zstd",
34                 _ => panic!("unexpected encoding {:?}", self.encoding),
35             };
36             assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
37 
38             req
39         }
40     }
41 
42     tokio::spawn({
43         let request_bytes_counter = request_bytes_counter.clone();
44         async move {
45             Server::builder()
46                 .layer(
47                     ServiceBuilder::new()
48                         .layer(
49                             ServiceBuilder::new()
50                                 .map_request(move |req| {
51                                     AssertRightEncoding::new(encoding).clone().call(req)
52                                 })
53                                 .layer(measure_request_body_size_layer(request_bytes_counter))
54                                 .into_inner(),
55                         )
56                         .into_inner(),
57                 )
58                 .add_service(svc)
59                 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
60                 .await
61                 .unwrap();
62         }
63     });
64 
65     let mut client =
66         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
67 
68     for _ in 0..3 {
69         client
70             .compress_input_unary(SomeData {
71                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
72             })
73             .await
74             .unwrap();
75         let bytes_sent = request_bytes_counter.load(SeqCst);
76         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
77     }
78 }
79 
80 util::parametrized_tests! {
81     client_enabled_server_enabled_multi_encoding,
82     zstd: CompressionEncoding::Zstd,
83     gzip: CompressionEncoding::Gzip,
84 }
85 
86 #[allow(dead_code)]
87 async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) {
88     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
89 
90     let svc = test_server::TestServer::new(Svc::default())
91         .accept_compressed(CompressionEncoding::Gzip)
92         .accept_compressed(CompressionEncoding::Zstd);
93 
94     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
95 
96     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
97         let supported_encodings = ["gzip", "zstd"];
98         let req_encoding = req.headers().get("grpc-encoding").unwrap();
99         assert!(supported_encodings.iter().any(|e| e == req_encoding));
100 
101         req
102     }
103 
104     tokio::spawn({
105         let request_bytes_counter = request_bytes_counter.clone();
106         async move {
107             Server::builder()
108                 .layer(
109                     ServiceBuilder::new()
110                         .layer(
111                             ServiceBuilder::new()
112                                 .map_request(assert_right_encoding)
113                                 .layer(measure_request_body_size_layer(request_bytes_counter))
114                                 .into_inner(),
115                         )
116                         .into_inner(),
117                 )
118                 .add_service(svc)
119                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
120                 .await
121                 .unwrap();
122         }
123     });
124 
125     let mut client =
126         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
127 
128     for _ in 0..3 {
129         client
130             .compress_input_unary(SomeData {
131                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
132             })
133             .await
134             .unwrap();
135         let bytes_sent = request_bytes_counter.load(SeqCst);
136         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
137     }
138 }
139 
140 parametrized_tests! {
141     client_enabled_server_disabled,
142     zstd: CompressionEncoding::Zstd,
143     gzip: CompressionEncoding::Gzip,
144 }
145 
146 #[allow(dead_code)]
147 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
148     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
149 
150     let svc = test_server::TestServer::new(Svc::default());
151 
152     tokio::spawn(async move {
153         Server::builder()
154             .add_service(svc)
155             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
156             .await
157             .unwrap();
158     });
159 
160     let mut client =
161         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
162 
163     let status = client
164         .compress_input_unary(SomeData {
165             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
166         })
167         .await
168         .unwrap_err();
169 
170     assert_eq!(status.code(), tonic::Code::Unimplemented);
171     let expected = match encoding {
172         CompressionEncoding::Gzip => "gzip",
173         CompressionEncoding::Zstd => "zstd",
174         _ => panic!("unexpected encoding {:?}", encoding),
175     };
176     assert_eq!(
177         status.message(),
178         format!(
179             "Content is compressed with `{}` which isn't supported",
180             expected
181         )
182     );
183 
184     assert_eq!(
185         status.metadata().get("grpc-accept-encoding").unwrap(),
186         "identity"
187     );
188 }
189 parametrized_tests! {
190     client_mark_compressed_without_header_server_enabled,
191     zstd: CompressionEncoding::Zstd,
192     gzip: CompressionEncoding::Gzip,
193 }
194 
195 #[allow(dead_code)]
196 async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) {
197     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
198 
199     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
200 
201     tokio::spawn({
202         async move {
203             Server::builder()
204                 .add_service(svc)
205                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
206                 .await
207                 .unwrap();
208         }
209     });
210 
211     let mut client = test_client::TestClient::with_interceptor(
212         mock_io_channel(client).await,
213         move |mut req: Request<()>| {
214             req.metadata_mut().remove("grpc-encoding");
215             Ok(req)
216         },
217     )
218     .send_compressed(CompressionEncoding::Gzip);
219 
220     let status = client
221         .compress_input_unary(SomeData {
222             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
223         })
224         .await
225         .unwrap_err();
226 
227     assert_eq!(status.code(), tonic::Code::Internal);
228     assert_eq!(
229         status.message(),
230         "protocol error: received message with compressed-flag but no grpc-encoding was specified"
231     );
232 }
233