1 use super::*;
2 use http_body::Body;
3 use tonic::codec::CompressionEncoding;
4
5 util::parametrized_tests! {
6 client_enabled_server_enabled,
7 zstd: CompressionEncoding::Zstd,
8 gzip: CompressionEncoding::Gzip,
9 deflate: CompressionEncoding::Deflate,
10 }
11
12 #[allow(dead_code)]
client_enabled_server_enabled(encoding: CompressionEncoding)13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15
16 let svc = test_server::TestServer::new(Svc::default())
17 .accept_compressed(encoding)
18 .send_compressed(encoding);
19
20 let request_bytes_counter = Arc::new(AtomicUsize::new(0));
21 let response_bytes_counter = Arc::new(AtomicUsize::new(0));
22
23 #[derive(Clone)]
24 pub struct AssertRightEncoding {
25 encoding: CompressionEncoding,
26 }
27
28 #[allow(dead_code)]
29 impl AssertRightEncoding {
30 pub fn new(encoding: CompressionEncoding) -> Self {
31 Self { encoding }
32 }
33
34 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
35 let expected = match self.encoding {
36 CompressionEncoding::Gzip => "gzip",
37 CompressionEncoding::Zstd => "zstd",
38 CompressionEncoding::Deflate => "deflate",
39 _ => panic!("unexpected encoding {:?}", self.encoding),
40 };
41 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
42
43 req
44 }
45 }
46
47 tokio::spawn({
48 let request_bytes_counter = request_bytes_counter.clone();
49 let response_bytes_counter = response_bytes_counter.clone();
50 async move {
51 Server::builder()
52 .layer(
53 ServiceBuilder::new()
54 .map_request(move |req| {
55 AssertRightEncoding::new(encoding).clone().call(req)
56 })
57 .layer(measure_request_body_size_layer(
58 request_bytes_counter.clone(),
59 ))
60 .layer(MapResponseBodyLayer::new(move |body| {
61 util::CountBytesBody {
62 inner: body,
63 counter: response_bytes_counter.clone(),
64 }
65 }))
66 .into_inner(),
67 )
68 .add_service(svc)
69 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
70 .await
71 .unwrap();
72 }
73 });
74
75 let mut client = test_client::TestClient::new(mock_io_channel(client).await)
76 .send_compressed(encoding)
77 .accept_compressed(encoding);
78
79 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
80 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
81 let req = Request::new(stream);
82
83 let res = client
84 .compress_input_output_bidirectional_stream(req)
85 .await
86 .unwrap();
87
88 let expected = match encoding {
89 CompressionEncoding::Gzip => "gzip",
90 CompressionEncoding::Zstd => "zstd",
91 CompressionEncoding::Deflate => "deflate",
92 _ => panic!("unexpected encoding {:?}", encoding),
93 };
94 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
95
96 let mut stream: Streaming<SomeData> = res.into_inner();
97
98 stream
99 .next()
100 .await
101 .expect("stream empty")
102 .expect("item was error");
103
104 stream
105 .next()
106 .await
107 .expect("stream empty")
108 .expect("item was error");
109
110 assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
111 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
112 }
113