1 use super::*; 2 3 #[tokio::test(flavor = "multi_thread")] 4 async fn client_enabled_server_enabled() { 5 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 6 7 #[derive(Clone, Copy)] 8 struct AssertCorrectAcceptEncoding<S>(S); 9 10 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 11 where 12 S: Service<http::Request<B>>, 13 { 14 type Response = S::Response; 15 type Error = S::Error; 16 type Future = S::Future; 17 18 fn poll_ready( 19 &mut self, 20 cx: &mut std::task::Context<'_>, 21 ) -> std::task::Poll<Result<(), Self::Error>> { 22 self.0.poll_ready(cx) 23 } 24 25 fn call(&mut self, req: http::Request<B>) -> Self::Future { 26 assert_eq!( 27 req.headers().get("grpc-accept-encoding").unwrap(), 28 "gzip,identity" 29 ); 30 self.0.call(req) 31 } 32 } 33 34 let svc = test_server::TestServer::new(Svc::default()).send_gzip(); 35 36 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 37 38 tokio::spawn({ 39 let response_bytes_counter = response_bytes_counter.clone(); 40 async move { 41 Server::builder() 42 .layer( 43 ServiceBuilder::new() 44 .layer(layer_fn(AssertCorrectAcceptEncoding)) 45 .layer(MapResponseBodyLayer::new(move |body| { 46 util::CountBytesBody { 47 inner: body, 48 counter: response_bytes_counter.clone(), 49 } 50 })) 51 .into_inner(), 52 ) 53 .add_service(svc) 54 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 55 MockStream(server), 56 )])) 57 .await 58 .unwrap(); 59 } 60 }); 61 62 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 63 64 for _ in 0..3 { 65 let res = client.compress_output_unary(()).await.unwrap(); 66 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 67 let bytes_sent = response_bytes_counter.load(SeqCst); 68 assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); 69 } 70 } 71 72 #[tokio::test(flavor = "multi_thread")] 73 async fn client_enabled_server_disabled() { 74 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 75 76 let svc = test_server::TestServer::new(Svc::default()); 77 78 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 79 80 tokio::spawn({ 81 let response_bytes_counter = response_bytes_counter.clone(); 82 async move { 83 Server::builder() 84 // no compression enable on the server so responses should not be compressed 85 .layer( 86 ServiceBuilder::new() 87 .layer(MapResponseBodyLayer::new(move |body| { 88 util::CountBytesBody { 89 inner: body, 90 counter: response_bytes_counter.clone(), 91 } 92 })) 93 .into_inner(), 94 ) 95 .add_service(svc) 96 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 97 MockStream(server), 98 )])) 99 .await 100 .unwrap(); 101 } 102 }); 103 104 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 105 106 let res = client.compress_output_unary(()).await.unwrap(); 107 108 assert!(res.metadata().get("grpc-encoding").is_none()); 109 110 let bytes_sent = response_bytes_counter.load(SeqCst); 111 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 112 } 113 114 #[tokio::test(flavor = "multi_thread")] 115 async fn client_disabled() { 116 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 117 118 #[derive(Clone, Copy)] 119 struct AssertCorrectAcceptEncoding<S>(S); 120 121 impl<S, B> Service<http::Request<B>> for AssertCorrectAcceptEncoding<S> 122 where 123 S: Service<http::Request<B>>, 124 { 125 type Response = S::Response; 126 type Error = S::Error; 127 type Future = S::Future; 128 129 fn poll_ready( 130 &mut self, 131 cx: &mut std::task::Context<'_>, 132 ) -> std::task::Poll<Result<(), Self::Error>> { 133 self.0.poll_ready(cx) 134 } 135 136 fn call(&mut self, req: http::Request<B>) -> Self::Future { 137 assert!(req.headers().get("grpc-accept-encoding").is_none()); 138 self.0.call(req) 139 } 140 } 141 142 let svc = test_server::TestServer::new(Svc::default()).send_gzip(); 143 144 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 145 146 tokio::spawn({ 147 let response_bytes_counter = response_bytes_counter.clone(); 148 async move { 149 Server::builder() 150 .layer( 151 ServiceBuilder::new() 152 .layer(layer_fn(AssertCorrectAcceptEncoding)) 153 .layer(MapResponseBodyLayer::new(move |body| { 154 util::CountBytesBody { 155 inner: body, 156 counter: response_bytes_counter.clone(), 157 } 158 })) 159 .into_inner(), 160 ) 161 .add_service(svc) 162 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 163 MockStream(server), 164 )])) 165 .await 166 .unwrap(); 167 } 168 }); 169 170 let mut client = test_client::TestClient::new(mock_io_channel(client).await); 171 172 let res = client.compress_output_unary(()).await.unwrap(); 173 174 assert!(res.metadata().get("grpc-encoding").is_none()); 175 176 let bytes_sent = response_bytes_counter.load(SeqCst); 177 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 178 } 179 180 #[tokio::test(flavor = "multi_thread")] 181 async fn server_replying_with_unsupported_encoding() { 182 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 183 184 let svc = test_server::TestServer::new(Svc::default()).send_gzip(); 185 186 fn add_weird_content_encoding<B>(mut response: http::Response<B>) -> http::Response<B> { 187 response 188 .headers_mut() 189 .insert("grpc-encoding", "br".parse().unwrap()); 190 response 191 } 192 193 tokio::spawn(async move { 194 Server::builder() 195 .layer( 196 ServiceBuilder::new() 197 .map_response(add_weird_content_encoding) 198 .into_inner(), 199 ) 200 .add_service(svc) 201 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 202 MockStream(server), 203 )])) 204 .await 205 .unwrap(); 206 }); 207 208 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 209 let status: Status = client.compress_output_unary(()).await.unwrap_err(); 210 211 assert_eq!(status.code(), tonic::Code::Unimplemented); 212 assert_eq!( 213 status.message(), 214 "Content is compressed with `br` which isn't supported" 215 ); 216 } 217 218 #[tokio::test(flavor = "multi_thread")] 219 async fn disabling_compression_on_single_response() { 220 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 221 222 let svc = test_server::TestServer::new(Svc { 223 disable_compressing_on_response: true, 224 }) 225 .send_gzip(); 226 227 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 228 229 tokio::spawn({ 230 let response_bytes_counter = response_bytes_counter.clone(); 231 async move { 232 Server::builder() 233 .layer( 234 ServiceBuilder::new() 235 .layer(MapResponseBodyLayer::new(move |body| { 236 util::CountBytesBody { 237 inner: body, 238 counter: response_bytes_counter.clone(), 239 } 240 })) 241 .into_inner(), 242 ) 243 .add_service(svc) 244 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 245 MockStream(server), 246 )])) 247 .await 248 .unwrap(); 249 } 250 }); 251 252 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 253 254 let res = client.compress_output_unary(()).await.unwrap(); 255 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 256 let bytes_sent = response_bytes_counter.load(SeqCst); 257 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 258 } 259 260 #[tokio::test(flavor = "multi_thread")] 261 async fn disabling_compression_on_response_but_keeping_compression_on_stream() { 262 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 263 264 let svc = test_server::TestServer::new(Svc { 265 disable_compressing_on_response: true, 266 }) 267 .send_gzip(); 268 269 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 270 271 tokio::spawn({ 272 let response_bytes_counter = response_bytes_counter.clone(); 273 async move { 274 Server::builder() 275 .layer( 276 ServiceBuilder::new() 277 .layer(MapResponseBodyLayer::new(move |body| { 278 util::CountBytesBody { 279 inner: body, 280 counter: response_bytes_counter.clone(), 281 } 282 })) 283 .into_inner(), 284 ) 285 .add_service(svc) 286 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 287 MockStream(server), 288 )])) 289 .await 290 .unwrap(); 291 } 292 }); 293 294 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 295 296 let res = client.compress_output_server_stream(()).await.unwrap(); 297 298 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 299 300 let mut stream: Streaming<SomeData> = res.into_inner(); 301 302 stream 303 .next() 304 .await 305 .expect("stream empty") 306 .expect("item was error"); 307 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 308 309 stream 310 .next() 311 .await 312 .expect("stream empty") 313 .expect("item was error"); 314 assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); 315 } 316 317 #[tokio::test(flavor = "multi_thread")] 318 async fn disabling_compression_on_response_from_client_stream() { 319 let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); 320 321 let svc = test_server::TestServer::new(Svc { 322 disable_compressing_on_response: true, 323 }) 324 .send_gzip(); 325 326 let response_bytes_counter = Arc::new(AtomicUsize::new(0)); 327 328 tokio::spawn({ 329 let response_bytes_counter = response_bytes_counter.clone(); 330 async move { 331 Server::builder() 332 .layer( 333 ServiceBuilder::new() 334 .layer(MapResponseBodyLayer::new(move |body| { 335 util::CountBytesBody { 336 inner: body, 337 counter: response_bytes_counter.clone(), 338 } 339 })) 340 .into_inner(), 341 ) 342 .add_service(svc) 343 .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( 344 MockStream(server), 345 )])) 346 .await 347 .unwrap(); 348 } 349 }); 350 351 let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); 352 353 let stream = futures::stream::iter(vec![]); 354 let req = Request::new(Box::pin(stream)); 355 356 let res = client.compress_output_client_stream(req).await.unwrap(); 357 assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 358 let bytes_sent = response_bytes_counter.load(SeqCst); 359 assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); 360 } 361