1 use super::*; 2 use tonic::codec::CompressionEncoding; 3 4 util::parametrized_tests! { 5 client_enabled_server_enabled, 6 zstd: CompressionEncoding::Zstd, 7 gzip: CompressionEncoding::Gzip, 8 deflate: CompressionEncoding::Deflate, 9 } 10 11 #[allow(dead_code)] 12 async fn client_enabled_server_enabled(encoding: CompressionEncoding) { 13 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 14 15 #[derive(Clone, Copy)] 16 struct AssertCorrectAcceptEncoding<S> { 17 service: S, 18 encoding: CompressionEncoding, 19 } 20 21 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 22 where 23 S: Service<http::Request<B>>, 24 { 25 type Response = S::Response; 26 type Error = S::Error; 27 type Future = S::Future; 28 29 fn poll_ready( 30 &mut self, 31 cx: &mut std::task::Context<'_>, 32 ) -> std::task::Poll<Result<(), Self::Error>> { 33 self.service.poll_ready(cx) 34 } 35 36 fn call(&mut self, req: http::Request<B>) -> Self::Future { 37 let expected = match self.encoding { 38 CompressionEncoding::Gzip => "gzip", 39 CompressionEncoding::Zstd => "zstd", 40 CompressionEncoding::Deflate => "deflate", 41 _ => panic!("unexpected encoding {:?}", self.encoding), 42 }; 43 assert_eq!( 44 req.headers() 45 .get("grpc-accept-encoding") 46 .unwrap() 47 .to_str() 48 .unwrap(), 49 format!("{},identity", expected) 50 ); 51 self.service.call(req) 52 } 53 } 54 55 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); 56 57 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 58 59 tokio::spawn({ 60 let response_bytes_counter = response_bytes_counter.clone(); 61 async move { 62 Server::builder() 63 .layer( 64 ServiceBuilder::new() 65 .layer(layer_fn(|service| AssertCorrectAcceptEncoding { 66 service, 67 encoding, 68 })) 69 .layer(MapResponseBodyLayer::new(move |body| { 70 util::CountBytesBody { 71 inner: body, 72 counter: response_bytes_counter.clone(), 73 } 74 })) 75 .into_inner(), 76 ) 77 .add_service(svc) 78 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 79 .await 80 .unwrap(); 81 } 82 }); 83 84 let mut client = 85 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 86 87 let expected = match encoding { 88 CompressionEncoding::Gzip => "gzip", 89 CompressionEncoding::Zstd => "zstd", 90 CompressionEncoding::Deflate => "deflate", 91 _ => panic!("unexpected encoding {:?}", encoding), 92 }; 93 94 for _ in 0..3 { 95 let res = client.compress_output_unary(()).await.unwrap(); 96 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 97 let bytes_sent = response_bytes_counter.load(SeqCst); 98 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 99 } 100 } 101 102 util::parametrized_tests! { 103 client_enabled_server_disabled, 104 zstd: CompressionEncoding::Zstd, 105 gzip: CompressionEncoding::Gzip, 106 deflate: CompressionEncoding::Deflate, 107 } 108 109 #[allow(dead_code)] 110 async fn client_enabled_server_disabled(encoding: CompressionEncoding) { 111 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 112 113 let svc = test_server::TestServer::new(Svc::default()); 114 115 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 116 117 tokio::spawn({ 118 let response_bytes_counter = response_bytes_counter.clone(); 119 async move { 120 Server::builder() 121 // no compression enable on the server so responses should not be compressed 122 .layer( 123 ServiceBuilder::new() 124 .layer(MapResponseBodyLayer::new(move |body| { 125 util::CountBytesBody { 126 inner: body, 127 counter: response_bytes_counter.clone(), 128 } 129 })) 130 .into_inner(), 131 ) 132 .add_service(svc) 133 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)])) 134 .await 135 .unwrap(); 136 } 137 }); 138 139 let mut client = 140 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 141 142 let res = client.compress_output_unary(()).await.unwrap(); 143 144 assert!(res.metadata().get("grpc-encoding").is_none()); 145 146 let bytes_sent = response_bytes_counter.load(SeqCst); 147 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 148 } 149 150 #[tokio::test(flavor = "multi_thread")] 151 async fn client_enabled_server_disabled_multi_encoding() { 152 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 153 154 let svc = test_server::TestServer::new(Svc::default()); 155 156 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 157 158 tokio::spawn({ 159 let response_bytes_counter = response_bytes_counter.clone(); 160 async move { 161 Server::builder() 162 // no compression enable on the server so responses should not be compressed 163 .layer( 164 ServiceBuilder::new() 165 .layer(MapResponseBodyLayer::new(move |body| { 166 util::CountBytesBody { 167 inner: body, 168 counter: response_bytes_counter.clone(), 169 } 170 })) 171 .into_inner(), 172 ) 173 .add_service(svc) 174 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 175 .await 176 .unwrap(); 177 } 178 }); 179 180 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 181 .accept_compressed(CompressionEncoding::Gzip) 182 .accept_compressed(CompressionEncoding::Zstd) 183 .accept_compressed(CompressionEncoding::Deflate); 184 185 let res = client.compress_output_unary(()).await.unwrap(); 186 187 assert!(res.metadata().get("grpc-encoding").is_none()); 188 189 let bytes_sent = response_bytes_counter.load(SeqCst); 190 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 191 } 192 193 util::parametrized_tests! { 194 client_disabled, 195 zstd: CompressionEncoding::Zstd, 196 gzip: CompressionEncoding::Gzip, 197 deflate: CompressionEncoding::Deflate, 198 } 199 200 #[allow(dead_code)] 201 async fn client_disabled(encoding: CompressionEncoding) { 202 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 203 204 #[derive(Clone, Copy)] 205 struct AssertCorrectAcceptEncoding<S>(S); 206 207 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 208 where 209 S: Service<http::Request<B>>, 210 { 211 type Response = S::Response; 212 type Error = S::Error; 213 type Future = S::Future; 214 215 fn poll_ready( 216 &mut self, 217 cx: &mut std::task::Context<'_>, 218 ) -> std::task::Poll<Result<(), Self::Error>> { 219 self.0.poll_ready(cx) 220 } 221 222 fn call(&mut self, req: http::Request<B>) -> Self::Future { 223 assert!(req.headers().get("grpc-accept-encoding").is_none()); 224 self.0.call(req) 225 } 226 } 227 228 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); 229 230 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 231 232 tokio::spawn({ 233 let response_bytes_counter = response_bytes_counter.clone(); 234 async move { 235 Server::builder() 236 .layer( 237 ServiceBuilder::new() 238 .layer(layer_fn(AssertCorrectAcceptEncoding)) 239 .layer(MapResponseBodyLayer::new(move |body| { 240 util::CountBytesBody { 241 inner: body, 242 counter: response_bytes_counter.clone(), 243 } 244 })) 245 .into_inner(), 246 ) 247 .add_service(svc) 248 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 249 .await 250 .unwrap(); 251 } 252 }); 253 254 let mut client = test_client::TestClient::new(mock_io_channel(client).await); 255 256 let res = client.compress_output_unary(()).await.unwrap(); 257 258 assert!(res.metadata().get("grpc-encoding").is_none()); 259 260 let bytes_sent = response_bytes_counter.load(SeqCst); 261 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 262 } 263 264 util::parametrized_tests! { 265 server_replying_with_unsupported_encoding, 266 zstd: CompressionEncoding::Zstd, 267 gzip: CompressionEncoding::Gzip, 268 deflate: CompressionEncoding::Deflate, 269 } 270 271 #[allow(dead_code)] 272 async fn server_replying_with_unsupported_encoding(encoding: CompressionEncoding) { 273 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 274 275 let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); 276 277 fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> { 278 response 279 .headers_mut() 280 .insert("grpc-encoding", "br".parse().unwrap()); 281 response 282 } 283 284 tokio::spawn(async move { 285 Server::builder() 286 .layer( 287 ServiceBuilder::new() 288 .map_response(add_weird_content_encoding) 289 .into_inner(), 290 ) 291 .add_service(svc) 292 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 293 .await 294 .unwrap(); 295 }); 296 297 let mut client = 298 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 299 let status: Status = client.compress_output_unary(()).await.unwrap_err(); 300 301 assert_eq!(status.code(), tonic::Code::Unimplemented); 302 assert_eq!( 303 status.message(), 304 "Content is compressed with `br` which isn't supported" 305 ); 306 } 307 308 util::parametrized_tests! { 309 disabling_compression_on_single_response, 310 zstd: CompressionEncoding::Zstd, 311 gzip: CompressionEncoding::Gzip, 312 deflate: CompressionEncoding::Deflate, 313 } 314 315 #[allow(dead_code)] 316 async fn disabling_compression_on_single_response(encoding: CompressionEncoding) { 317 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 318 319 let svc = test_server::TestServer::new(Svc { 320 disable_compressing_on_response: true, 321 }) 322 .send_compressed(encoding); 323 324 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 325 326 tokio::spawn({ 327 let response_bytes_counter = response_bytes_counter.clone(); 328 async move { 329 Server::builder() 330 .layer( 331 ServiceBuilder::new() 332 .layer(MapResponseBodyLayer::new(move |body| { 333 util::CountBytesBody { 334 inner: body, 335 counter: response_bytes_counter.clone(), 336 } 337 })) 338 .into_inner(), 339 ) 340 .add_service(svc) 341 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 342 .await 343 .unwrap(); 344 } 345 }); 346 347 let mut client = 348 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 349 350 let res = client.compress_output_unary(()).await.unwrap(); 351 352 let expected = match encoding { 353 CompressionEncoding::Gzip => "gzip", 354 CompressionEncoding::Zstd => "zstd", 355 CompressionEncoding::Deflate => "deflate", 356 _ => panic!("unexpected encoding {:?}", encoding), 357 }; 358 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 359 360 let bytes_sent = response_bytes_counter.load(SeqCst); 361 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 362 } 363 364 util::parametrized_tests! { 365 disabling_compression_on_response_but_keeping_compression_on_stream, 366 zstd: CompressionEncoding::Zstd, 367 gzip: CompressionEncoding::Gzip, 368 deflate: CompressionEncoding::Deflate, 369 } 370 371 #[allow(dead_code)] 372 async fn disabling_compression_on_response_but_keeping_compression_on_stream( 373 encoding: CompressionEncoding, 374 ) { 375 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 376 377 let svc = test_server::TestServer::new(Svc { 378 disable_compressing_on_response: true, 379 }) 380 .send_compressed(encoding); 381 382 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 383 384 tokio::spawn({ 385 let response_bytes_counter = response_bytes_counter.clone(); 386 async move { 387 Server::builder() 388 .layer( 389 ServiceBuilder::new() 390 .layer(MapResponseBodyLayer::new(move |body| { 391 util::CountBytesBody { 392 inner: body, 393 counter: response_bytes_counter.clone(), 394 } 395 })) 396 .into_inner(), 397 ) 398 .add_service(svc) 399 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 400 .await 401 .unwrap(); 402 } 403 }); 404 405 let mut client = 406 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 407 408 let res = client.compress_output_server_stream(()).await.unwrap(); 409 410 let expected = match encoding { 411 CompressionEncoding::Gzip => "gzip", 412 CompressionEncoding::Zstd => "zstd", 413 CompressionEncoding::Deflate => "deflate", 414 _ => panic!("unexpected encoding {:?}", encoding), 415 }; 416 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 417 418 let mut stream: Streaming<SomeData> = res.into_inner(); 419 420 stream 421 .next() 422 .await 423 .expect("stream empty") 424 .expect("item was error"); 425 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 426 427 stream 428 .next() 429 .await 430 .expect("stream empty") 431 .expect("item was error"); 432 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 433 } 434 435 util::parametrized_tests! { 436 disabling_compression_on_response_from_client_stream, 437 zstd: CompressionEncoding::Zstd, 438 gzip: CompressionEncoding::Gzip, 439 deflate: CompressionEncoding::Deflate, 440 } 441 442 #[allow(dead_code)] 443 async fn disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding) { 444 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 445 446 let svc = test_server::TestServer::new(Svc { 447 disable_compressing_on_response: true, 448 }) 449 .send_compressed(encoding); 450 451 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 452 453 tokio::spawn({ 454 let response_bytes_counter = response_bytes_counter.clone(); 455 async move { 456 Server::builder() 457 .layer( 458 ServiceBuilder::new() 459 .layer(MapResponseBodyLayer::new(move |body| { 460 util::CountBytesBody { 461 inner: body, 462 counter: response_bytes_counter.clone(), 463 } 464 })) 465 .into_inner(), 466 ) 467 .add_service(svc) 468 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 469 .await 470 .unwrap(); 471 } 472 }); 473 474 let mut client = 475 test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); 476 477 let req = Request::new(Box::pin(tokio_stream::empty())); 478 479 let res = client.compress_output_client_stream(req).await.unwrap(); 480 481 let expected = match encoding { 482 CompressionEncoding::Gzip => "gzip", 483 CompressionEncoding::Zstd => "zstd", 484 CompressionEncoding::Deflate => "deflate", 485 _ => panic!("unexpected encoding {:?}", encoding), 486 }; 487 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 488 let bytes_sent = response_bytes_counter.load(SeqCst); 489 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 490 } 491