1 use super::*; 2 use tonic::codec::CompressionEncoding; 3 4 #[tokio::test(flavor = "multi_thread")] 5 async fn client_enabled_server_enabled() { 6 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 7 8 #[derive(Clone, Copy)] 9 struct AssertCorrectAcceptEncoding<S>(S); 10 11 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 12 where 13 S: Service<http::Request<B>>, 14 { 15 type Response = S::Response; 16 type Error = S::Error; 17 type Future = S::Future; 18 19 fn poll_ready( 20 &mut self, 21 cx: &mut std::task::Context<'_>, 22 ) -> std::task::Poll<Result<(), Self::Error>> { 23 self.0.poll_ready(cx) 24 } 25 26 fn call(&mut self, req: http::Request<B>) -> Self::Future { 27 assert_eq!( 28 req.headers().get("grpc-accept-encoding").unwrap(), 29 "gzip,identity" 30 ); 31 self.0.call(req) 32 } 33 } 34 35 let svc = 36 test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); 37 38 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 39 40 tokio::spawn({ 41 let response_bytes_counter = response_bytes_counter.clone(); 42 async move { 43 Server::builder() 44 .layer( 45 ServiceBuilder::new() 46 .layer(layer_fn(AssertCorrectAcceptEncoding)) 47 .layer(MapResponseBodyLayer::new(move |body| { 48 util::CountBytesBody { 49 inner: body, 50 counter: response_bytes_counter.clone(), 51 } 52 })) 53 .into_inner(), 54 ) 55 .add_service(svc) 56 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 57 .await 58 .unwrap(); 59 } 60 }); 61 62 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 63 .accept_compressed(CompressionEncoding::Gzip); 64 65 for _ in 0..3 { 66 let res = client.compress_output_unary(()).await.unwrap(); 67 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 68 let bytes_sent = response_bytes_counter.load(SeqCst); 69 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 70 } 71 } 72 73 #[tokio::test(flavor = "multi_thread")] 74 async fn client_enabled_server_disabled() { 75 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 76 77 let svc = test_server::TestServer::new(Svc::default()); 78 79 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 80 81 tokio::spawn({ 82 let response_bytes_counter = response_bytes_counter.clone(); 83 async move { 84 Server::builder() 85 // no compression enable on the server so responses should not be compressed 86 .layer( 87 ServiceBuilder::new() 88 .layer(MapResponseBodyLayer::new(move |body| { 89 util::CountBytesBody { 90 inner: body, 91 counter: response_bytes_counter.clone(), 92 } 93 })) 94 .into_inner(), 95 ) 96 .add_service(svc) 97 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 98 .await 99 .unwrap(); 100 } 101 }); 102 103 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 104 .accept_compressed(CompressionEncoding::Gzip); 105 106 let res = client.compress_output_unary(()).await.unwrap(); 107 108 assert!(res.metadata().get("grpc-encoding").is_none()); 109 110 let bytes_sent = response_bytes_counter.load(SeqCst); 111 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 112 } 113 114 #[tokio::test(flavor = "multi_thread")] 115 async fn client_disabled() { 116 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 117 118 #[derive(Clone, Copy)] 119 struct AssertCorrectAcceptEncoding<S>(S); 120 121 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 122 where 123 S: Service<http::Request<B>>, 124 { 125 type Response = S::Response; 126 type Error = S::Error; 127 type Future = S::Future; 128 129 fn poll_ready( 130 &mut self, 131 cx: &mut std::task::Context<'_>, 132 ) -> std::task::Poll<Result<(), Self::Error>> { 133 self.0.poll_ready(cx) 134 } 135 136 fn call(&mut self, req: http::Request<B>) -> Self::Future { 137 assert!(req.headers().get("grpc-accept-encoding").is_none()); 138 self.0.call(req) 139 } 140 } 141 142 let svc = 143 test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); 144 145 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 146 147 tokio::spawn({ 148 let response_bytes_counter = response_bytes_counter.clone(); 149 async move { 150 Server::builder() 151 .layer( 152 ServiceBuilder::new() 153 .layer(layer_fn(AssertCorrectAcceptEncoding)) 154 .layer(MapResponseBodyLayer::new(move |body| { 155 util::CountBytesBody { 156 inner: body, 157 counter: response_bytes_counter.clone(), 158 } 159 })) 160 .into_inner(), 161 ) 162 .add_service(svc) 163 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 164 .await 165 .unwrap(); 166 } 167 }); 168 169 let mut client = test_client::TestClient::new(mock_io_channel(client).await); 170 171 let res = client.compress_output_unary(()).await.unwrap(); 172 173 assert!(res.metadata().get("grpc-encoding").is_none()); 174 175 let bytes_sent = response_bytes_counter.load(SeqCst); 176 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 177 } 178 179 #[tokio::test(flavor = "multi_thread")] 180 async fn server_replying_with_unsupported_encoding() { 181 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 182 183 let svc = 184 test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); 185 186 fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> { 187 response 188 .headers_mut() 189 .insert("grpc-encoding", "br".parse().unwrap()); 190 response 191 } 192 193 tokio::spawn(async move { 194 Server::builder() 195 .layer( 196 ServiceBuilder::new() 197 .map_response(add_weird_content_encoding) 198 .into_inner(), 199 ) 200 .add_service(svc) 201 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 202 .await 203 .unwrap(); 204 }); 205 206 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 207 .accept_compressed(CompressionEncoding::Gzip); 208 let status: Status = client.compress_output_unary(()).await.unwrap_err(); 209 210 assert_eq!(status.code(), tonic::Code::Unimplemented); 211 assert_eq!( 212 status.message(), 213 "Content is compressed with `br` which isn't supported" 214 ); 215 } 216 217 #[tokio::test(flavor = "multi_thread")] 218 async fn disabling_compression_on_single_response() { 219 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 220 221 let svc = test_server::TestServer::new(Svc { 222 disable_compressing_on_response: true, 223 }) 224 .send_compressed(CompressionEncoding::Gzip); 225 226 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 227 228 tokio::spawn({ 229 let response_bytes_counter = response_bytes_counter.clone(); 230 async move { 231 Server::builder() 232 .layer( 233 ServiceBuilder::new() 234 .layer(MapResponseBodyLayer::new(move |body| { 235 util::CountBytesBody { 236 inner: body, 237 counter: response_bytes_counter.clone(), 238 } 239 })) 240 .into_inner(), 241 ) 242 .add_service(svc) 243 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 244 .await 245 .unwrap(); 246 } 247 }); 248 249 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 250 .accept_compressed(CompressionEncoding::Gzip); 251 252 let res = client.compress_output_unary(()).await.unwrap(); 253 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 254 let bytes_sent = response_bytes_counter.load(SeqCst); 255 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 256 } 257 258 #[tokio::test(flavor = "multi_thread")] 259 async fn disabling_compression_on_response_but_keeping_compression_on_stream() { 260 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 261 262 let svc = test_server::TestServer::new(Svc { 263 disable_compressing_on_response: true, 264 }) 265 .send_compressed(CompressionEncoding::Gzip); 266 267 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 268 269 tokio::spawn({ 270 let response_bytes_counter = response_bytes_counter.clone(); 271 async move { 272 Server::builder() 273 .layer( 274 ServiceBuilder::new() 275 .layer(MapResponseBodyLayer::new(move |body| { 276 util::CountBytesBody { 277 inner: body, 278 counter: response_bytes_counter.clone(), 279 } 280 })) 281 .into_inner(), 282 ) 283 .add_service(svc) 284 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 285 .await 286 .unwrap(); 287 } 288 }); 289 290 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 291 .accept_compressed(CompressionEncoding::Gzip); 292 293 let res = client.compress_output_server_stream(()).await.unwrap(); 294 295 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 296 297 let mut stream: Streaming<SomeData> = res.into_inner(); 298 299 stream 300 .next() 301 .await 302 .expect("stream empty") 303 .expect("item was error"); 304 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 305 306 stream 307 .next() 308 .await 309 .expect("stream empty") 310 .expect("item was error"); 311 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 312 } 313 314 #[tokio::test(flavor = "multi_thread")] 315 async fn disabling_compression_on_response_from_client_stream() { 316 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 317 318 let svc = test_server::TestServer::new(Svc { 319 disable_compressing_on_response: true, 320 }) 321 .send_compressed(CompressionEncoding::Gzip); 322 323 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 324 325 tokio::spawn({ 326 let response_bytes_counter = response_bytes_counter.clone(); 327 async move { 328 Server::builder() 329 .layer( 330 ServiceBuilder::new() 331 .layer(MapResponseBodyLayer::new(move |body| { 332 util::CountBytesBody { 333 inner: body, 334 counter: response_bytes_counter.clone(), 335 } 336 })) 337 .into_inner(), 338 ) 339 .add_service(svc) 340 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 341 .await 342 .unwrap(); 343 } 344 }); 345 346 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 347 .accept_compressed(CompressionEncoding::Gzip); 348 349 let req = Request::new(Box::pin(tokio_stream::empty())); 350 351 let res = client.compress_output_client_stream(req).await.unwrap(); 352 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 353 let bytes_sent = response_bytes_counter.load(SeqCst); 354 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 355 } 356