10583cff8SDavid Pedersen use super::*;
2e8cb48fcSQuentin Perez use http_body::Body;
3a585a722SMarcus Griep use tonic::codec::CompressionEncoding;
40583cff8SDavid Pedersen 
5e8cb48fcSQuentin Perez util::parametrized_tests! {
6e8cb48fcSQuentin Perez     client_enabled_server_enabled,
7e8cb48fcSQuentin Perez     zstd: CompressionEncoding::Zstd,
8e8cb48fcSQuentin Perez     gzip: CompressionEncoding::Gzip,
9*79a06cc8SIlya Averyanov     deflate: CompressionEncoding::Deflate,
10e8cb48fcSQuentin Perez }
11e8cb48fcSQuentin Perez 
12e8cb48fcSQuentin Perez #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)13e8cb48fcSQuentin Perez async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
140583cff8SDavid Pedersen     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
150583cff8SDavid Pedersen 
16e8cb48fcSQuentin Perez     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
17e8cb48fcSQuentin Perez 
18e8cb48fcSQuentin Perez     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
19e8cb48fcSQuentin Perez 
20e8cb48fcSQuentin Perez     #[derive(Clone)]
21e8cb48fcSQuentin Perez     pub struct AssertRightEncoding {
22e8cb48fcSQuentin Perez         encoding: CompressionEncoding,
23e8cb48fcSQuentin Perez     }
24e8cb48fcSQuentin Perez 
25e8cb48fcSQuentin Perez     #[allow(dead_code)]
26e8cb48fcSQuentin Perez     impl AssertRightEncoding {
27e8cb48fcSQuentin Perez         pub fn new(encoding: CompressionEncoding) -> Self {
28e8cb48fcSQuentin Perez             Self { encoding }
29e8cb48fcSQuentin Perez         }
30e8cb48fcSQuentin Perez 
31e8cb48fcSQuentin Perez         pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
32e8cb48fcSQuentin Perez             let expected = match self.encoding {
33e8cb48fcSQuentin Perez                 CompressionEncoding::Gzip => "gzip",
34e8cb48fcSQuentin Perez                 CompressionEncoding::Zstd => "zstd",
35*79a06cc8SIlya Averyanov                 CompressionEncoding::Deflate => "deflate",
36e8cb48fcSQuentin Perez                 _ => panic!("unexpected encoding {:?}", self.encoding),
37e8cb48fcSQuentin Perez             };
38e8cb48fcSQuentin Perez             assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
39e8cb48fcSQuentin Perez 
40e8cb48fcSQuentin Perez             req
41e8cb48fcSQuentin Perez         }
42e8cb48fcSQuentin Perez     }
43e8cb48fcSQuentin Perez 
44e8cb48fcSQuentin Perez     tokio::spawn({
45e8cb48fcSQuentin Perez         let request_bytes_counter = request_bytes_counter.clone();
46e8cb48fcSQuentin Perez         async move {
47e8cb48fcSQuentin Perez             Server::builder()
48e8cb48fcSQuentin Perez                 .layer(
49e8cb48fcSQuentin Perez                     ServiceBuilder::new()
50e8cb48fcSQuentin Perez                         .layer(
51e8cb48fcSQuentin Perez                             ServiceBuilder::new()
52e8cb48fcSQuentin Perez                                 .map_request(move |req| {
53e8cb48fcSQuentin Perez                                     AssertRightEncoding::new(encoding).clone().call(req)
54e8cb48fcSQuentin Perez                                 })
55e8cb48fcSQuentin Perez                                 .layer(measure_request_body_size_layer(request_bytes_counter))
56e8cb48fcSQuentin Perez                                 .into_inner(),
57e8cb48fcSQuentin Perez                         )
58e8cb48fcSQuentin Perez                         .into_inner(),
59e8cb48fcSQuentin Perez                 )
60e8cb48fcSQuentin Perez                 .add_service(svc)
61e8cb48fcSQuentin Perez                 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)]))
62e8cb48fcSQuentin Perez                 .await
63e8cb48fcSQuentin Perez                 .unwrap();
64e8cb48fcSQuentin Perez         }
65e8cb48fcSQuentin Perez     });
66e8cb48fcSQuentin Perez 
67e8cb48fcSQuentin Perez     let mut client =
68e8cb48fcSQuentin Perez         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
69e8cb48fcSQuentin Perez 
70e8cb48fcSQuentin Perez     for _ in 0..3 {
71e8cb48fcSQuentin Perez         client
72e8cb48fcSQuentin Perez             .compress_input_unary(SomeData {
73e8cb48fcSQuentin Perez                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
74e8cb48fcSQuentin Perez             })
75e8cb48fcSQuentin Perez             .await
76e8cb48fcSQuentin Perez             .unwrap();
77e8cb48fcSQuentin Perez         let bytes_sent = request_bytes_counter.load(SeqCst);
78e8cb48fcSQuentin Perez         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
79e8cb48fcSQuentin Perez     }
80e8cb48fcSQuentin Perez }
81e8cb48fcSQuentin Perez 
82e8cb48fcSQuentin Perez util::parametrized_tests! {
83e8cb48fcSQuentin Perez     client_enabled_server_enabled_multi_encoding,
84e8cb48fcSQuentin Perez     zstd: CompressionEncoding::Zstd,
85e8cb48fcSQuentin Perez     gzip: CompressionEncoding::Gzip,
86*79a06cc8SIlya Averyanov     deflate: CompressionEncoding::Deflate,
87e8cb48fcSQuentin Perez }
88e8cb48fcSQuentin Perez 
89e8cb48fcSQuentin Perez #[allow(dead_code)]
client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding)90e8cb48fcSQuentin Perez async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) {
91e8cb48fcSQuentin Perez     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
92e8cb48fcSQuentin Perez 
93e8cb48fcSQuentin Perez     let svc = test_server::TestServer::new(Svc::default())
94e8cb48fcSQuentin Perez         .accept_compressed(CompressionEncoding::Gzip)
95*79a06cc8SIlya Averyanov         .accept_compressed(CompressionEncoding::Zstd)
96*79a06cc8SIlya Averyanov         .accept_compressed(CompressionEncoding::Deflate);
970583cff8SDavid Pedersen 
980583cff8SDavid Pedersen     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
990583cff8SDavid Pedersen 
1000583cff8SDavid Pedersen     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
101*79a06cc8SIlya Averyanov         let supported_encodings = ["gzip", "zstd", "deflate"];
102e8cb48fcSQuentin Perez         let req_encoding = req.headers().get("grpc-encoding").unwrap();
103e8cb48fcSQuentin Perez         assert!(supported_encodings.iter().any(|e| e == req_encoding));
104e8cb48fcSQuentin Perez 
1050583cff8SDavid Pedersen         req
1060583cff8SDavid Pedersen     }
1070583cff8SDavid Pedersen 
1080583cff8SDavid Pedersen     tokio::spawn({
1090583cff8SDavid Pedersen         let request_bytes_counter = request_bytes_counter.clone();
1100583cff8SDavid Pedersen         async move {
1110583cff8SDavid Pedersen             Server::builder()
1120583cff8SDavid Pedersen                 .layer(
1130583cff8SDavid Pedersen                     ServiceBuilder::new()
1140583cff8SDavid Pedersen                         .layer(
1150583cff8SDavid Pedersen                             ServiceBuilder::new()
1160583cff8SDavid Pedersen                                 .map_request(assert_right_encoding)
1170583cff8SDavid Pedersen                                 .layer(measure_request_body_size_layer(request_bytes_counter))
1180583cff8SDavid Pedersen                                 .into_inner(),
1190583cff8SDavid Pedersen                         )
1200583cff8SDavid Pedersen                         .into_inner(),
1210583cff8SDavid Pedersen                 )
1220583cff8SDavid Pedersen                 .add_service(svc)
123f089e7a0Stottoto                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
1240583cff8SDavid Pedersen                 .await
1250583cff8SDavid Pedersen                 .unwrap();
1260583cff8SDavid Pedersen         }
1270583cff8SDavid Pedersen     });
1280583cff8SDavid Pedersen 
129e8cb48fcSQuentin Perez     let mut client =
130e8cb48fcSQuentin Perez         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
1310583cff8SDavid Pedersen 
1320583cff8SDavid Pedersen     for _ in 0..3 {
1330583cff8SDavid Pedersen         client
1340583cff8SDavid Pedersen             .compress_input_unary(SomeData {
1350583cff8SDavid Pedersen                 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
1360583cff8SDavid Pedersen             })
1370583cff8SDavid Pedersen             .await
1380583cff8SDavid Pedersen             .unwrap();
1390583cff8SDavid Pedersen         let bytes_sent = request_bytes_counter.load(SeqCst);
1400583cff8SDavid Pedersen         assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
1410583cff8SDavid Pedersen     }
1420583cff8SDavid Pedersen }
1430583cff8SDavid Pedersen 
144e8cb48fcSQuentin Perez parametrized_tests! {
145e8cb48fcSQuentin Perez     client_enabled_server_disabled,
146e8cb48fcSQuentin Perez     zstd: CompressionEncoding::Zstd,
147e8cb48fcSQuentin Perez     gzip: CompressionEncoding::Gzip,
148*79a06cc8SIlya Averyanov     deflate: CompressionEncoding::Deflate,
149e8cb48fcSQuentin Perez }
150e8cb48fcSQuentin Perez 
151e8cb48fcSQuentin Perez #[allow(dead_code)]
client_enabled_server_disabled(encoding: CompressionEncoding)152e8cb48fcSQuentin Perez async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
1530583cff8SDavid Pedersen     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
1540583cff8SDavid Pedersen 
1550583cff8SDavid Pedersen     let svc = test_server::TestServer::new(Svc::default());
1560583cff8SDavid Pedersen 
1570583cff8SDavid Pedersen     tokio::spawn(async move {
1580583cff8SDavid Pedersen         Server::builder()
1590583cff8SDavid Pedersen             .add_service(svc)
160f089e7a0Stottoto             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
1610583cff8SDavid Pedersen             .await
1620583cff8SDavid Pedersen             .unwrap();
1630583cff8SDavid Pedersen     });
1640583cff8SDavid Pedersen 
165e8cb48fcSQuentin Perez     let mut client =
166e8cb48fcSQuentin Perez         test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
1670583cff8SDavid Pedersen 
1680583cff8SDavid Pedersen     let status = client
1690583cff8SDavid Pedersen         .compress_input_unary(SomeData {
1700583cff8SDavid Pedersen             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
1710583cff8SDavid Pedersen         })
1720583cff8SDavid Pedersen         .await
1730583cff8SDavid Pedersen         .unwrap_err();
1740583cff8SDavid Pedersen 
1750583cff8SDavid Pedersen     assert_eq!(status.code(), tonic::Code::Unimplemented);
176e8cb48fcSQuentin Perez     let expected = match encoding {
177e8cb48fcSQuentin Perez         CompressionEncoding::Gzip => "gzip",
178e8cb48fcSQuentin Perez         CompressionEncoding::Zstd => "zstd",
179*79a06cc8SIlya Averyanov         CompressionEncoding::Deflate => "deflate",
180e8cb48fcSQuentin Perez         _ => panic!("unexpected encoding {:?}", encoding),
181e8cb48fcSQuentin Perez     };
1820583cff8SDavid Pedersen     assert_eq!(
1830583cff8SDavid Pedersen         status.message(),
184e8cb48fcSQuentin Perez         format!(
185e8cb48fcSQuentin Perez             "Content is compressed with `{}` which isn't supported",
186e8cb48fcSQuentin Perez             expected
187e8cb48fcSQuentin Perez         )
1880583cff8SDavid Pedersen     );
1890583cff8SDavid Pedersen 
1900583cff8SDavid Pedersen     assert_eq!(
1910583cff8SDavid Pedersen         status.metadata().get("grpc-accept-encoding").unwrap(),
1920583cff8SDavid Pedersen         "identity"
1930583cff8SDavid Pedersen     );
1940583cff8SDavid Pedersen }
195e8cb48fcSQuentin Perez parametrized_tests! {
196e8cb48fcSQuentin Perez     client_mark_compressed_without_header_server_enabled,
197e8cb48fcSQuentin Perez     zstd: CompressionEncoding::Zstd,
198e8cb48fcSQuentin Perez     gzip: CompressionEncoding::Gzip,
199*79a06cc8SIlya Averyanov     deflate: CompressionEncoding::Deflate,
200e8cb48fcSQuentin Perez }
20188008191SJulien Roncaglia 
202e8cb48fcSQuentin Perez #[allow(dead_code)]
client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding)203e8cb48fcSQuentin Perez async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) {
20488008191SJulien Roncaglia     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
20588008191SJulien Roncaglia 
206e8cb48fcSQuentin Perez     let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
20788008191SJulien Roncaglia 
20888008191SJulien Roncaglia     tokio::spawn({
20988008191SJulien Roncaglia         async move {
21088008191SJulien Roncaglia             Server::builder()
21188008191SJulien Roncaglia                 .add_service(svc)
212f089e7a0Stottoto                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
21388008191SJulien Roncaglia                 .await
21488008191SJulien Roncaglia                 .unwrap();
21588008191SJulien Roncaglia         }
21688008191SJulien Roncaglia     });
21788008191SJulien Roncaglia 
21888008191SJulien Roncaglia     let mut client = test_client::TestClient::with_interceptor(
21988008191SJulien Roncaglia         mock_io_channel(client).await,
22088008191SJulien Roncaglia         move |mut req: Request<()>| {
22188008191SJulien Roncaglia             req.metadata_mut().remove("grpc-encoding");
22288008191SJulien Roncaglia             Ok(req)
22388008191SJulien Roncaglia         },
22488008191SJulien Roncaglia     )
225a585a722SMarcus Griep     .send_compressed(CompressionEncoding::Gzip);
22688008191SJulien Roncaglia 
22788008191SJulien Roncaglia     let status = client
22888008191SJulien Roncaglia         .compress_input_unary(SomeData {
22988008191SJulien Roncaglia             data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(),
23088008191SJulien Roncaglia         })
23188008191SJulien Roncaglia         .await
23288008191SJulien Roncaglia         .unwrap_err();
23388008191SJulien Roncaglia 
23488008191SJulien Roncaglia     assert_eq!(status.code(), tonic::Code::Internal);
23588008191SJulien Roncaglia     assert_eq!(
23688008191SJulien Roncaglia         status.message(),
23788008191SJulien Roncaglia         "protocol error: received message with compressed-flag but no grpc-encoding was specified"
23888008191SJulien Roncaglia     );
23988008191SJulien Roncaglia }
240