1 use super::*; 2 use http_body::Body; 3 use tonic::codec::CompressionEncoding; 4 5 util::parametrized_tests! { 6 client_enabled_server_enabled, 7 zstd: CompressionEncoding::Zstd, 8 gzip: CompressionEncoding::Gzip, 9 deflate: CompressionEncoding::Deflate, 10 } 11 12 #[allow(dead_code)] 13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) { 14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 15 16 let svc = test_server::TestServer::new(Svc::default()) 17 .accept_compressed(encoding) 18 .send_compressed(encoding); 19 20 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 21 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 22 23 #[derive(Clone)] 24 pub struct AssertRightEncoding { 25 encoding: CompressionEncoding, 26 } 27 28 #[allow(dead_code)] 29 impl AssertRightEncoding { 30 pub fn new(encoding: CompressionEncoding) -> Self { 31 Self { encoding } 32 } 33 34 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> { 35 let expected = match self.encoding { 36 CompressionEncoding::Gzip => "gzip", 37 CompressionEncoding::Zstd => "zstd", 38 CompressionEncoding::Deflate => "deflate", 39 _ => panic!("unexpected encoding {:?}", self.encoding), 40 }; 41 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); 42 43 req 44 } 45 } 46 47 tokio::spawn({ 48 let request_bytes_counter = request_bytes_counter.clone(); 49 let response_bytes_counter = response_bytes_counter.clone(); 50 async move { 51 Server::builder() 52 .layer( 53 ServiceBuilder::new() 54 .map_request(move |req| { 55 AssertRightEncoding::new(encoding).clone().call(req) 56 }) 57 .layer(measure_request_body_size_layer( 58 request_bytes_counter.clone(), 59 )) 60 .layer(MapResponseBodyLayer::new(move |body| { 61 util::CountBytesBody { 62 inner: body, 63 counter: response_bytes_counter.clone(), 64 } 65 })) 66 .into_inner(), 67 ) 68 .add_service(svc) 69 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 70 .await 71 .unwrap(); 72 } 73 }); 74 75 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 76 .send_compressed(encoding) 77 .accept_compressed(encoding); 78 79 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 80 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 81 let req = Request::new(stream); 82 83 let res = client 84 .compress_input_output_bidirectional_stream(req) 85 .await 86 .unwrap(); 87 88 let expected = match encoding { 89 CompressionEncoding::Gzip => "gzip", 90 CompressionEncoding::Zstd => "zstd", 91 CompressionEncoding::Deflate => "deflate", 92 _ => panic!("unexpected encoding {:?}", encoding), 93 }; 94 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); 95 96 let mut stream: Streaming<SomeData> = res.into_inner(); 97 98 stream 99 .next() 100 .await 101 .expect("stream empty") 102 .expect("item was error"); 103 104 stream 105 .next() 106 .await 107 .expect("stream empty") 108 .expect("item was error"); 109 110 assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 111 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 112 } 113