1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4 
// Registers `client_enabled_server_enabled` with three named cases, one per compression
// encoding listed below.
// NOTE(review): the macro is defined in `util`, outside this file; presumably it expands to
// one test per `name: value` entry that calls the function with that value — confirm there.
util::parametrized_tests! {
    client_enabled_server_enabled,
    zstd: CompressionEncoding::Zstd,
    gzip: CompressionEncoding::Gzip,
    deflate: CompressionEncoding::Deflate,
}
11 
12 #[allow(dead_code)]
13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15 
16     let svc = test_server::TestServer::new(Svc::default())
17         .accept_compressed(encoding)
18         .send_compressed(encoding);
19 
20     let request_bytes_counter = Arc::new(AtomicUsize::new(0));
21     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
22 
23     #[derive(Clone)]
24     pub struct AssertRightEncoding {
25         encoding: CompressionEncoding,
26     }
27 
28     #[allow(dead_code)]
29     impl AssertRightEncoding {
30         pub fn new(encoding: CompressionEncoding) -> Self {
31             Self { encoding }
32         }
33 
34         pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
35             let expected = match self.encoding {
36                 CompressionEncoding::Gzip => "gzip",
37                 CompressionEncoding::Zstd => "zstd",
38                 CompressionEncoding::Deflate => "deflate",
39                 _ => panic!("unexpected encoding {:?}", self.encoding),
40             };
41             assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
42 
43             req
44         }
45     }
46 
47     tokio::spawn({
48         let request_bytes_counter = request_bytes_counter.clone();
49         let response_bytes_counter = response_bytes_counter.clone();
50         async move {
51             Server::builder()
52                 .layer(
53                     ServiceBuilder::new()
54                         .map_request(move |req| {
55                             AssertRightEncoding::new(encoding).clone().call(req)
56                         })
57                         .layer(measure_request_body_size_layer(
58                             request_bytes_counter.clone(),
59                         ))
60                         .layer(MapResponseBodyLayer::new(move |body| {
61                             util::CountBytesBody {
62                                 inner: body,
63                                 counter: response_bytes_counter.clone(),
64                             }
65                         }))
66                         .into_inner(),
67                 )
68                 .add_service(svc)
69                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
70                 .await
71                 .unwrap();
72         }
73     });
74 
75     let mut client = test_client::TestClient::new(mock_io_channel(client).await)
76         .send_compressed(encoding)
77         .accept_compressed(encoding);
78 
79     let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
80     let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
81     let req = Request::new(stream);
82 
83     let res = client
84         .compress_input_output_bidirectional_stream(req)
85         .await
86         .unwrap();
87 
88     let expected = match encoding {
89         CompressionEncoding::Gzip => "gzip",
90         CompressionEncoding::Zstd => "zstd",
91         CompressionEncoding::Deflate => "deflate",
92         _ => panic!("unexpected encoding {:?}", encoding),
93     };
94     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
95 
96     let mut stream: Streaming<SomeData> = res.into_inner();
97 
98     stream
99         .next()
100         .await
101         .expect("stream empty")
102         .expect("item was error");
103 
104     stream
105         .next()
106         .await
107         .expect("stream empty")
108         .expect("item was error");
109 
110     assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
111     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
112 }
113