1 use super::*; 2 use http_body::Body; 3 use tonic::codec::CompressionEncoding; 4 5 util::parametrized_tests! { 6 client_enabled_server_enabled, 7 zstd: CompressionEncoding::Zstd, 8 gzip: CompressionEncoding::Gzip, 9 deflate: CompressionEncoding::Deflate, 10 } 11 12 #[allow(dead_code)] 13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) { 14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 15 16 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 17 18 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 19 20 #[derive(Clone)] 21 pub struct AssertRightEncoding { 22 encoding: CompressionEncoding, 23 } 24 25 #[allow(dead_code)] 26 impl AssertRightEncoding { 27 pub fn new(encoding: CompressionEncoding) -> Self { 28 Self { encoding } 29 } 30 31 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> { 32 let expected = match self.encoding { 33 CompressionEncoding::Gzip => "gzip", 34 CompressionEncoding::Zstd => "zstd", 35 CompressionEncoding::Deflate => "deflate", 36 _ => panic!("unexpected encoding {:?}", self.encoding), 37 }; 38 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); 39 40 req 41 } 42 } 43 44 tokio::spawn({ 45 let request_bytes_counter = request_bytes_counter.clone(); 46 async move { 47 Server::builder() 48 .layer( 49 ServiceBuilder::new() 50 .map_request(move |req| { 51 AssertRightEncoding::new(encoding).clone().call(req) 52 }) 53 .layer(measure_request_body_size_layer( 54 request_bytes_counter.clone(), 55 )) 56 .into_inner(), 57 ) 58 .add_service(svc) 59 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 60 .await 61 .unwrap(); 62 } 63 }); 64 65 let mut client = 66 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 67 68 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 69 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 70 let req = Request::new(Box::pin(stream)); 71 72 client.compress_input_client_stream(req).await.unwrap(); 73 74 let bytes_sent = request_bytes_counter.load(SeqCst); 75 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 76 } 77 78 util::parametrized_tests! { 79 client_disabled_server_enabled, 80 zstd: CompressionEncoding::Zstd, 81 gzip: CompressionEncoding::Gzip, 82 deflate: CompressionEncoding::Deflate, 83 } 84 85 #[allow(dead_code)] 86 async fn client_disabled_server_enabled(encoding: CompressionEncoding) { 87 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 88 89 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 90 91 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 92 93 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> { 94 assert!(req.headers().get("grpc-encoding").is_none()); 95 req 96 } 97 98 tokio::spawn({ 99 let request_bytes_counter = request_bytes_counter.clone(); 100 async move { 101 Server::builder() 102 .layer( 103 ServiceBuilder::new() 104 .map_request(assert_right_encoding) 105 .layer(measure_request_body_size_layer( 106 request_bytes_counter.clone(), 107 )) 108 .into_inner(), 109 ) 110 .add_service(svc) 111 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 112 .await 113 .unwrap(); 114 } 115 }); 116 117 let mut client = test_client::TestClient::new(mock_io_channel(client).await); 118 119 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 120 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 121 let req = Request::new(Box::pin(stream)); 122 123 client.compress_input_client_stream(req).await.unwrap(); 124 125 let bytes_sent = request_bytes_counter.load(SeqCst); 126 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 127 } 128 129 util::parametrized_tests! { 130 client_enabled_server_disabled, 131 zstd: CompressionEncoding::Zstd, 132 gzip: CompressionEncoding::Gzip, 133 deflate: CompressionEncoding::Deflate, 134 } 135 136 #[allow(dead_code)] 137 async fn client_enabled_server_disabled(encoding: CompressionEncoding) { 138 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 139 140 let svc = test_server::TestServer::new(Svc::default()); 141 142 tokio::spawn(async move { 143 Server::builder() 144 .add_service(svc) 145 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 146 .await 147 .unwrap(); 148 }); 149 150 let mut client = 151 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 152 153 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 154 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 155 let req = Request::new(Box::pin(stream)); 156 157 let status = client.compress_input_client_stream(req).await.unwrap_err(); 158 159 assert_eq!(status.code(), tonic::Code::Unimplemented); 160 let expected = match encoding { 161 CompressionEncoding::Gzip => "gzip", 162 CompressionEncoding::Zstd => "zstd", 163 CompressionEncoding::Deflate => "deflate", 164 _ => panic!("unexpected encoding {:?}", encoding), 165 }; 166 assert_eq!( 167 status.message(), 168 format!( 169 "Content is compressed with `{}` which isn't supported", 170 expected 171 ) 172 ); 173 } 174 175 util::parametrized_tests! { 176 compressing_response_from_client_stream, 177 zstd: CompressionEncoding::Zstd, 178 gzip: CompressionEncoding::Gzip, 179 deflate: CompressionEncoding::Deflate, 180 } 181 182 #[allow(dead_code)] 183 async fn compressing_response_from_client_stream(encoding: CompressionEncoding) { 184 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 185 186 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); 187 188 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 189 190 tokio::spawn({ 191 let response_bytes_counter = response_bytes_counter.clone(); 192 async move { 193 Server::builder() 194 .layer( 195 ServiceBuilder::new() 196 .layer(MapResponseBodyLayer::new(move |body| { 197 util::CountBytesBody { 198 inner: body, 199 counter: response_bytes_counter.clone(), 200 } 201 })) 202 .into_inner(), 203 ) 204 .add_service(svc) 205 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 206 .await 207 .unwrap(); 208 } 209 }); 210 211 let mut client = 212 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 213 214 let req = Request::new(Box::pin(tokio_stream::empty())); 215 216 let res = client.compress_output_client_stream(req).await.unwrap(); 217 let expected = match encoding { 218 CompressionEncoding::Gzip => "gzip", 219 CompressionEncoding::Zstd => "zstd", 220 CompressionEncoding::Deflate => "deflate", 221 _ => panic!("unexpected encoding {:?}", encoding), 222 }; 223 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 224 let bytes_sent = response_bytes_counter.load(SeqCst); 225 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 226 } 227