1 use super::*; 2 use http_body::Body; 3 use tonic::codec::CompressionEncoding; 4 5 util::parametrized_tests! { 6 client_enabled_server_enabled, 7 zstd: CompressionEncoding::Zstd, 8 gzip: CompressionEncoding::Gzip, 9 } 10 11 #[allow(dead_code)] 12 async fn client_enabled_server_enabled(encoding: CompressionEncoding) { 13 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 14 15 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 16 17 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 18 19 #[derive(Clone)] 20 pub struct AssertRightEncoding { 21 encoding: CompressionEncoding, 22 } 23 24 #[allow(dead_code)] 25 impl AssertRightEncoding { 26 pub fn new(encoding: CompressionEncoding) -> Self { 27 Self { encoding } 28 } 29 30 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> { 31 let expected = match self.encoding { 32 CompressionEncoding::Gzip => "gzip", 33 CompressionEncoding::Zstd => "zstd", 34 _ => panic!("unexpected encoding {:?}", self.encoding), 35 }; 36 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); 37 38 req 39 } 40 } 41 42 tokio::spawn({ 43 let request_bytes_counter = request_bytes_counter.clone(); 44 async move { 45 Server::builder() 46 .layer( 47 ServiceBuilder::new() 48 .layer( 49 ServiceBuilder::new() 50 .map_request(move |req| { 51 AssertRightEncoding::new(encoding).clone().call(req) 52 }) 53 .layer(measure_request_body_size_layer(request_bytes_counter)) 54 .into_inner(), 55 ) 56 .into_inner(), 57 ) 58 .add_service(svc) 59 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)])) 60 .await 61 .unwrap(); 62 } 63 }); 64 65 let mut client = 66 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 67 68 for _ in 0..3 { 69 client 70 .compress_input_unary(SomeData { 71 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 72 }) 73 .await 74 .unwrap(); 75 let bytes_sent = request_bytes_counter.load(SeqCst); 76 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 77 } 78 } 79 80 util::parametrized_tests! { 81 client_enabled_server_enabled_multi_encoding, 82 zstd: CompressionEncoding::Zstd, 83 gzip: CompressionEncoding::Gzip, 84 } 85 86 #[allow(dead_code)] 87 async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) { 88 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 89 90 let svc = test_server::TestServer::new(Svc::default()) 91 .accept_compressed(CompressionEncoding::Gzip) 92 .accept_compressed(CompressionEncoding::Zstd); 93 94 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 95 96 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> { 97 let supported_encodings = ["gzip", "zstd"]; 98 let req_encoding = req.headers().get("grpc-encoding").unwrap(); 99 assert!(supported_encodings.iter().any(|e| e == req_encoding)); 100 101 req 102 } 103 104 tokio::spawn({ 105 let request_bytes_counter = request_bytes_counter.clone(); 106 async move { 107 Server::builder() 108 .layer( 109 ServiceBuilder::new() 110 .layer( 111 ServiceBuilder::new() 112 .map_request(assert_right_encoding) 113 .layer(measure_request_body_size_layer(request_bytes_counter)) 114 .into_inner(), 115 ) 116 .into_inner(), 117 ) 118 .add_service(svc) 119 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 120 .await 121 .unwrap(); 122 } 123 }); 124 125 let mut client = 126 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 127 128 for _ in 0..3 { 129 client 130 .compress_input_unary(SomeData { 131 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 132 }) 133 .await 134 .unwrap(); 135 let bytes_sent = request_bytes_counter.load(SeqCst); 136 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 137 } 138 } 139 140 parametrized_tests! { 141 client_enabled_server_disabled, 142 zstd: CompressionEncoding::Zstd, 143 gzip: CompressionEncoding::Gzip, 144 } 145 146 #[allow(dead_code)] 147 async fn client_enabled_server_disabled(encoding: CompressionEncoding) { 148 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 149 150 let svc = test_server::TestServer::new(Svc::default()); 151 152 tokio::spawn(async move { 153 Server::builder() 154 .add_service(svc) 155 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 156 .await 157 .unwrap(); 158 }); 159 160 let mut client = 161 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 162 163 let status = client 164 .compress_input_unary(SomeData { 165 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 166 }) 167 .await 168 .unwrap_err(); 169 170 assert_eq!(status.code(), tonic::Code::Unimplemented); 171 let expected = match encoding { 172 CompressionEncoding::Gzip => "gzip", 173 CompressionEncoding::Zstd => "zstd", 174 _ => panic!("unexpected encoding {:?}", encoding), 175 }; 176 assert_eq!( 177 status.message(), 178 format!( 179 "Content is compressed with `{}` which isn't supported", 180 expected 181 ) 182 ); 183 184 assert_eq!( 185 status.metadata().get("grpc-accept-encoding").unwrap(), 186 "identity" 187 ); 188 } 189 parametrized_tests! { 190 client_mark_compressed_without_header_server_enabled, 191 zstd: CompressionEncoding::Zstd, 192 gzip: CompressionEncoding::Gzip, 193 } 194 195 #[allow(dead_code)] 196 async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) { 197 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 198 199 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 200 201 tokio::spawn({ 202 async move { 203 Server::builder() 204 .add_service(svc) 205 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 206 .await 207 .unwrap(); 208 } 209 }); 210 211 let mut client = test_client::TestClient::with_interceptor( 212 mock_io_channel(client).await, 213 move |mut req: Request<()>| { 214 req.metadata_mut().remove("grpc-encoding"); 215 Ok(req) 216 }, 217 ) 218 .send_compressed(CompressionEncoding::Gzip); 219 220 let status = client 221 .compress_input_unary(SomeData { 222 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 223 }) 224 .await 225 .unwrap_err(); 226 227 assert_eq!(status.code(), tonic::Code::Internal); 228 assert_eq!( 229 status.message(), 230 "protocol error: received message with compressed-flag but no grpc-encoding was specified" 231 ); 232 } 233