1 use super::*; 2 use http_body::Body as _; 3 use tonic::codec::CompressionEncoding; 4 5 #[tokio::test(flavor = "multi_thread")] 6 async fn client_enabled_server_enabled() { 7 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 8 9 let svc = 10 test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); 11 12 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 13 14 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> { 15 assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); 16 req 17 } 18 19 tokio::spawn({ 20 let request_bytes_counter = request_bytes_counter.clone(); 21 async move { 22 Server::builder() 23 .layer( 24 ServiceBuilder::new() 25 .map_request(assert_right_encoding) 26 .layer(measure_request_body_size_layer( 27 request_bytes_counter.clone(), 28 )) 29 .into_inner(), 30 ) 31 .add_service(svc) 32 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 33 .await 34 .unwrap(); 35 } 36 }); 37 38 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 39 .send_compressed(CompressionEncoding::Gzip); 40 41 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 42 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 43 let req = Request::new(Box::pin(stream)); 44 45 client.compress_input_client_stream(req).await.unwrap(); 46 47 let bytes_sent = request_bytes_counter.load(SeqCst); 48 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 49 } 50 51 #[tokio::test(flavor = "multi_thread")] 52 async fn client_disabled_server_enabled() { 53 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 54 55 let svc = 56 test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); 57 58 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 59 60 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> { 61 assert!(req.headers().get("grpc-encoding").is_none()); 62 req 63 } 64 65 tokio::spawn({ 66 let request_bytes_counter = request_bytes_counter.clone(); 67 async move { 68 Server::builder() 69 .layer( 70 ServiceBuilder::new() 71 .map_request(assert_right_encoding) 72 .layer(measure_request_body_size_layer( 73 request_bytes_counter.clone(), 74 )) 75 .into_inner(), 76 ) 77 .add_service(svc) 78 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 79 .await 80 .unwrap(); 81 } 82 }); 83 84 let mut client = test_client::TestClient::new(mock_io_channel(client).await); 85 86 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 87 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 88 let req = Request::new(Box::pin(stream)); 89 90 client.compress_input_client_stream(req).await.unwrap(); 91 92 let bytes_sent = request_bytes_counter.load(SeqCst); 93 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 94 } 95 96 #[tokio::test(flavor = "multi_thread")] 97 async fn client_enabled_server_disabled() { 98 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 99 100 let svc = test_server::TestServer::new(Svc::default()); 101 102 tokio::spawn(async move { 103 Server::builder() 104 .add_service(svc) 105 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 106 .await 107 .unwrap(); 108 }); 109 110 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 111 .send_compressed(CompressionEncoding::Gzip); 112 113 let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); 114 let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); 115 let req = Request::new(Box::pin(stream)); 116 117 let status = client.compress_input_client_stream(req).await.unwrap_err(); 118 119 assert_eq!(status.code(), tonic::Code::Unimplemented); 120 assert_eq!( 121 status.message(), 122 "Content is compressed with `gzip` which isn't supported" 123 ); 124 } 125 126 #[tokio::test(flavor = "multi_thread")] 127 async fn compressing_response_from_client_stream() { 128 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 129 130 let svc = 131 test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); 132 133 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 134 135 tokio::spawn({ 136 let response_bytes_counter = response_bytes_counter.clone(); 137 async move { 138 Server::builder() 139 .layer( 140 ServiceBuilder::new() 141 .layer(MapResponseBodyLayer::new(move |body| { 142 util::CountBytesBody { 143 inner: body, 144 counter: response_bytes_counter.clone(), 145 } 146 })) 147 .into_inner(), 148 ) 149 .add_service(svc) 150 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 151 .await 152 .unwrap(); 153 } 154 }); 155 156 let mut client = test_client::TestClient::new(mock_io_channel(client).await) 157 .accept_compressed(CompressionEncoding::Gzip); 158 159 let req = Request::new(Box::pin(tokio_stream::empty())); 160 161 let res = client.compress_output_client_stream(req).await.unwrap(); 162 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 163 let bytes_sent = response_bytes_counter.load(SeqCst); 164 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 165 } 166