1 use super::*;
2 use http_body::Body as _;
3 use tonic::codec::CompressionEncoding;
4 
5 #[tokio::test(flavor = "multi_thread")]
6 async fn client_enabled_server_enabled() {
7     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
8 
9     let svc =
10         test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
11 
12     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
13 
14     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
15         assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
16         req
17     }
18 
19     tokio::spawn({
20         let request_bytes_counter = request_bytes_counter.clone();
21         async move {
22             Server::builder()
23                 .layer(
24                     ServiceBuilder::new()
25                         .map_request(assert_right_encoding)
26                         .layer(measure_request_body_size_layer(
27                             request_bytes_counter.clone(),
28                         ))
29                         .into_inner(),
30                 )
31                 .add_service(svc)
32                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
33                 .await
34                 .unwrap();
35         }
36     });
37 
38     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
39         .send_compressed(CompressionEncoding::Gzip);
40 
41     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
42     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
43     let req = Request::new(Box::pin(stream));
44 
45     client.compress_input_client_stream(req).await.unwrap();
46 
47     let bytes_sent = request_bytes_counter.load(SeqCst);
48     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
49 }
50 
51 #[tokio::test(flavor = "multi_thread")]
52 async fn client_disabled_server_enabled() {
53     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
54 
55     let svc =
56         test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
57 
58     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
59 
60     fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
61         assert!(req.headers().get("grpc-encoding").is_none());
62         req
63     }
64 
65     tokio::spawn({
66         let request_bytes_counter = request_bytes_counter.clone();
67         async move {
68             Server::builder()
69                 .layer(
70                     ServiceBuilder::new()
71                         .map_request(assert_right_encoding)
72                         .layer(measure_request_body_size_layer(
73                             request_bytes_counter.clone(),
74                         ))
75                         .into_inner(),
76                 )
77                 .add_service(svc)
78                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
79                 .await
80                 .unwrap();
81         }
82     });
83 
84     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
85 
86     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
87     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
88     let req = Request::new(Box::pin(stream));
89 
90     client.compress_input_client_stream(req).await.unwrap();
91 
92     let bytes_sent = request_bytes_counter.load(SeqCst);
93     assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
94 }
95 
96 #[tokio::test(flavor = "multi_thread")]
97 async fn client_enabled_server_disabled() {
98     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
99 
100     let svc = test_server::TestServer::new(Svc::default());
101 
102     tokio::spawn(async move {
103         Server::builder()
104             .add_service(svc)
105             .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
106             .await
107             .unwrap();
108     });
109 
110     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
111         .send_compressed(CompressionEncoding::Gzip);
112 
113     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
114     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
115     let req = Request::new(Box::pin(stream));
116 
117     let status = client.compress_input_client_stream(req).await.unwrap_err();
118 
119     assert_eq!(status.code(), tonic::Code::Unimplemented);
120     assert_eq!(
121         status.message(),
122         "Content is compressed with `gzip` which isn't supported"
123     );
124 }
125 
126 #[tokio::test(flavor = "multi_thread")]
127 async fn compressing_response_from_client_stream() {
128     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
129 
130     let svc =
131         test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
132 
133     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
134 
135     tokio::spawn({
136         let response_bytes_counter = response_bytes_counter.clone();
137         async move {
138             Server::builder()
139                 .layer(
140                     ServiceBuilder::new()
141                         .layer(MapResponseBodyLayer::new(move |body| {
142                             util::CountBytesBody {
143                                 inner: body,
144                                 counter: response_bytes_counter.clone(),
145                             }
146                         }))
147                         .into_inner(),
148                 )
149                 .add_service(svc)
150                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
151                 .await
152                 .unwrap();
153         }
154     });
155 
156     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
157         .accept_compressed(CompressionEncoding::Gzip);
158 
159     let req = Request::new(Box::pin(tokio_stream::empty()));
160 
161     let res = client.compress_output_client_stream(req).await.unwrap();
162     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
163     let bytes_sent = response_bytes_counter.load(SeqCst);
164     assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
165 }
166