1 use super::*;
2 use tonic::codec::CompressionEncoding;
3 use tonic::Streaming;
4 
5 util::parametrized_tests! {
6     client_enabled_server_enabled,
7     zstd: CompressionEncoding::Zstd,
8     gzip: CompressionEncoding::Gzip,
9     deflate: CompressionEncoding::Deflate,
10 }
11 
12 #[allow(dead_code)]
13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
14     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
15 
16     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
17 
18     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
19 
20     tokio::spawn({
21         let response_bytes_counter = response_bytes_counter.clone();
22         async move {
23             Server::builder()
24                 .layer(
25                     ServiceBuilder::new()
26                         .layer(MapResponseBodyLayer::new(move |body| {
27                             util::CountBytesBody {
28                                 inner: body,
29                                 counter: response_bytes_counter.clone(),
30                             }
31                         }))
32                         .into_inner(),
33                 )
34                 .add_service(svc)
35                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
36                 .await
37                 .unwrap();
38         }
39     });
40 
41     let mut client =
42         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
43 
44     let res = client.compress_output_server_stream(()).await.unwrap();
45 
46     let expected = match encoding {
47         CompressionEncoding::Gzip => "gzip",
48         CompressionEncoding::Zstd => "zstd",
49         CompressionEncoding::Deflate => "deflate",
50         _ => panic!("unexpected encoding {:?}", encoding),
51     };
52     assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
53 
54     let mut stream: Streaming<SomeData> = res.into_inner();
55 
56     stream
57         .next()
58         .await
59         .expect("stream empty")
60         .expect("item was error");
61     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
62 
63     stream
64         .next()
65         .await
66         .expect("stream empty")
67         .expect("item was error");
68     assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
69 }
70 
71 util::parametrized_tests! {
72     client_disabled_server_enabled,
73     zstd: CompressionEncoding::Zstd,
74     gzip: CompressionEncoding::Gzip,
75 }
76 
77 #[allow(dead_code)]
78 async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
79     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
80 
81     let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
82 
83     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
84 
85     tokio::spawn({
86         let response_bytes_counter = response_bytes_counter.clone();
87         async move {
88             Server::builder()
89                 .layer(
90                     ServiceBuilder::new()
91                         .layer(MapResponseBodyLayer::new(move |body| {
92                             util::CountBytesBody {
93                                 inner: body,
94                                 counter: response_bytes_counter.clone(),
95                             }
96                         }))
97                         .into_inner(),
98                 )
99                 .add_service(svc)
100                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
101                 .await
102                 .unwrap();
103         }
104     });
105 
106     let mut client = test_client::TestClient::new(mock_io_channel(client).await);
107 
108     let res = client.compress_output_server_stream(()).await.unwrap();
109 
110     assert!(res.metadata().get("grpc-encoding").is_none());
111 
112     let mut stream: Streaming<SomeData> = res.into_inner();
113 
114     stream
115         .next()
116         .await
117         .expect("stream empty")
118         .expect("item was error");
119     assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE);
120 }
121 
122 util::parametrized_tests! {
123     client_enabled_server_disabled,
124     zstd: CompressionEncoding::Zstd,
125     gzip: CompressionEncoding::Gzip,
126     deflate: CompressionEncoding::Deflate,
127 }
128 
129 #[allow(dead_code)]
130 async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
131     let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
132 
133     let svc = test_server::TestServer::new(Svc::default());
134 
135     let response_bytes_counter = Arc::new(AtomicUsize::new(0));
136 
137     tokio::spawn({
138         let response_bytes_counter = response_bytes_counter.clone();
139         async move {
140             Server::builder()
141                 .layer(
142                     ServiceBuilder::new()
143                         .layer(MapResponseBodyLayer::new(move |body| {
144                             util::CountBytesBody {
145                                 inner: body,
146                                 counter: response_bytes_counter.clone(),
147                             }
148                         }))
149                         .into_inner(),
150                 )
151                 .add_service(svc)
152                 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server)))
153                 .await
154                 .unwrap();
155         }
156     });
157 
158     let mut client =
159         test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
160 
161     let res = client.compress_output_server_stream(()).await.unwrap();
162 
163     assert!(res.metadata().get("grpc-encoding").is_none());
164 
165     let mut stream: Streaming<SomeData> = res.into_inner();
166 
167     stream
168         .next()
169         .await
170         .expect("stream empty")
171         .expect("item was error");
172     assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE);
173 }
174