1 use super::*; 2 use http_body::Body; 3 use tonic::codec::CompressionEncoding; 4 5 util::parametrized_tests! { 6 client_enabled_server_enabled, 7 zstd: CompressionEncoding::Zstd, 8 gzip: CompressionEncoding::Gzip, 9 deflate: CompressionEncoding::Deflate, 10 } 11 12 #[allow(dead_code)] 13 async fn client_enabled_server_enabled(encoding: CompressionEncoding) { 14 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 15 16 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 17 18 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 19 20 #[derive(Clone)] 21 pub struct AssertRightEncoding { 22 encoding: CompressionEncoding, 23 } 24 25 #[allow(dead_code)] 26 impl AssertRightEncoding { 27 pub fn new(encoding: CompressionEncoding) -> Self { 28 Self { encoding } 29 } 30 31 pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> { 32 let expected = match self.encoding { 33 CompressionEncoding::Gzip => "gzip", 34 CompressionEncoding::Zstd => "zstd", 35 CompressionEncoding::Deflate => "deflate", 36 _ => panic!("unexpected encoding {:?}", self.encoding), 37 }; 38 assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); 39 40 req 41 } 42 } 43 44 tokio::spawn({ 45 let request_bytes_counter = request_bytes_counter.clone(); 46 async move { 47 Server::builder() 48 .layer( 49 ServiceBuilder::new() 50 .layer( 51 ServiceBuilder::new() 52 .map_request(move |req| { 53 AssertRightEncoding::new(encoding).clone().call(req) 54 }) 55 .layer(measure_request_body_size_layer(request_bytes_counter)) 56 .into_inner(), 57 ) 58 .into_inner(), 59 ) 60 .add_service(svc) 61 .serve_with_incoming(tokio_stream::iter(vec![Ok::<_, std::io::Error>(server)])) 62 .await 63 .unwrap(); 64 } 65 }); 66 67 let mut client = 68 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 69 70 for _ in 0..3 { 71 client 72 .compress_input_unary(SomeData { 73 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 74 }) 75 .await 76 .unwrap(); 77 let bytes_sent = request_bytes_counter.load(SeqCst); 78 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 79 } 80 } 81 82 util::parametrized_tests! { 83 client_enabled_server_enabled_multi_encoding, 84 zstd: CompressionEncoding::Zstd, 85 gzip: CompressionEncoding::Gzip, 86 deflate: CompressionEncoding::Deflate, 87 } 88 89 #[allow(dead_code)] 90 async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) { 91 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 92 93 let svc = test_server::TestServer::new(Svc::default()) 94 .accept_compressed(CompressionEncoding::Gzip) 95 .accept_compressed(CompressionEncoding::Zstd) 96 .accept_compressed(CompressionEncoding::Deflate); 97 98 let request_bytes_counter = Arc::new(AtomicUsize::new(0)); 99 100 fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> { 101 let supported_encodings = ["gzip", "zstd", "deflate"]; 102 let req_encoding = req.headers().get("grpc-encoding").unwrap(); 103 assert!(supported_encodings.iter().any(|e| e == req_encoding)); 104 105 req 106 } 107 108 tokio::spawn({ 109 let request_bytes_counter = request_bytes_counter.clone(); 110 async move { 111 Server::builder() 112 .layer( 113 ServiceBuilder::new() 114 .layer( 115 ServiceBuilder::new() 116 .map_request(assert_right_encoding) 117 .layer(measure_request_body_size_layer(request_bytes_counter)) 118 .into_inner(), 119 ) 120 .into_inner(), 121 ) 122 .add_service(svc) 123 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 124 .await 125 .unwrap(); 126 } 127 }); 128 129 let mut client = 130 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 131 132 for _ in 0..3 { 133 client 134 .compress_input_unary(SomeData { 135 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 136 }) 137 .await 138 .unwrap(); 139 let bytes_sent = request_bytes_counter.load(SeqCst); 140 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 141 } 142 } 143 144 parametrized_tests! { 145 client_enabled_server_disabled, 146 zstd: CompressionEncoding::Zstd, 147 gzip: CompressionEncoding::Gzip, 148 deflate: CompressionEncoding::Deflate, 149 } 150 151 #[allow(dead_code)] 152 async fn client_enabled_server_disabled(encoding: CompressionEncoding) { 153 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 154 155 let svc = test_server::TestServer::new(Svc::default()); 156 157 tokio::spawn(async move { 158 Server::builder() 159 .add_service(svc) 160 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 161 .await 162 .unwrap(); 163 }); 164 165 let mut client = 166 test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); 167 168 let status = client 169 .compress_input_unary(SomeData { 170 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 171 }) 172 .await 173 .unwrap_err(); 174 175 assert_eq!(status.code(), tonic::Code::Unimplemented); 176 let expected = match encoding { 177 CompressionEncoding::Gzip => "gzip", 178 CompressionEncoding::Zstd => "zstd", 179 CompressionEncoding::Deflate => "deflate", 180 _ => panic!("unexpected encoding {:?}", encoding), 181 }; 182 assert_eq!( 183 status.message(), 184 format!( 185 "Content is compressed with `{}` which isn't supported", 186 expected 187 ) 188 ); 189 190 assert_eq!( 191 status.metadata().get("grpc-accept-encoding").unwrap(), 192 "identity" 193 ); 194 } 195 parametrized_tests! { 196 client_mark_compressed_without_header_server_enabled, 197 zstd: CompressionEncoding::Zstd, 198 gzip: CompressionEncoding::Gzip, 199 deflate: CompressionEncoding::Deflate, 200 } 201 202 #[allow(dead_code)] 203 async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) { 204 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 205 206 let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); 207 208 tokio::spawn({ 209 async move { 210 Server::builder() 211 .add_service(svc) 212 .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) 213 .await 214 .unwrap(); 215 } 216 }); 217 218 let mut client = test_client::TestClient::with_interceptor( 219 mock_io_channel(client).await, 220 move |mut req: Request<()>| { 221 req.metadata_mut().remove("grpc-encoding"); 222 Ok(req) 223 }, 224 ) 225 .send_compressed(CompressionEncoding::Gzip); 226 227 let status = client 228 .compress_input_unary(SomeData { 229 data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), 230 }) 231 .await 232 .unwrap_err(); 233 234 assert_eq!(status.code(), tonic::Code::Internal); 235 assert_eq!( 236 status.message(), 237 "protocol error: received message with compressed-flag but no grpc-encoding was specified" 238 ); 239 } 240