1 use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; 2 use crate::codec::EncodeBody; 3 use crate::metadata::GRPC_CONTENT_TYPE; 4 use crate::{ 5 body::Body, 6 client::GrpcService, 7 codec::{Codec, Decoder, Streaming}, 8 request::SanitizeHeaders, 9 Code, Request, Response, Status, 10 }; 11 use http::{ 12 header::{HeaderValue, CONTENT_TYPE, TE}, 13 uri::{PathAndQuery, Uri}, 14 }; 15 use http_body::Body as HttpBody; 16 use std::{fmt, future, pin::pin}; 17 use tokio_stream::{Stream, StreamExt}; 18 19 /// A gRPC client dispatcher. 20 /// 21 /// This will wrap some inner [`GrpcService`] and will encode/decode 22 /// messages via the provided codec. 23 /// 24 /// Each request method takes a [`Request`], a [`PathAndQuery`], and a 25 /// [`Codec`]. The request contains the message to send via the 26 /// [`Codec::encoder`]. The path determines the fully qualified path 27 /// that will be append to the outgoing uri. The path must follow 28 /// the conventions explained in the [gRPC protocol definition] under `Path →`. An 29 /// example of this path could look like `/greeter.Greeter/SayHello`. 30 /// 31 /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests 32 pub struct Grpc<T> { 33 inner: T, 34 config: GrpcConfig, 35 } 36 37 struct GrpcConfig { 38 origin: Uri, 39 /// Which compression encodings does the client accept? 40 accept_compression_encodings: EnabledCompressionEncodings, 41 /// The compression encoding that will be applied to requests. 42 send_compression_encodings: Option<CompressionEncoding>, 43 /// Limits the maximum size of a decoded message. 44 max_decoding_message_size: Option<usize>, 45 /// Limits the maximum size of an encoded message. 46 max_encoding_message_size: Option<usize>, 47 } 48 49 impl<T> Grpc<T> { 50 /// Creates a new gRPC client with the provided [`GrpcService`]. new(inner: T) -> Self51 pub fn new(inner: T) -> Self { 52 Self::with_origin(inner, Uri::default()) 53 } 54 55 /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. 56 /// 57 /// The provided Uri will use only the scheme and authority parts as the 58 /// path_and_query portion will be set for each method. with_origin(inner: T, origin: Uri) -> Self59 pub fn with_origin(inner: T, origin: Uri) -> Self { 60 Self { 61 inner, 62 config: GrpcConfig { 63 origin, 64 send_compression_encodings: None, 65 accept_compression_encodings: EnabledCompressionEncodings::default(), 66 max_decoding_message_size: None, 67 max_encoding_message_size: None, 68 }, 69 } 70 } 71 72 /// Compress requests with the provided encoding. 73 /// 74 /// Requires the server to accept the specified encoding, otherwise it might return an error. 75 /// 76 /// # Example 77 /// 78 /// The most common way of using this is through a client generated by tonic-build: 79 /// 80 /// ```rust 81 /// use tonic::transport::Channel; 82 /// # enum CompressionEncoding { Gzip } 83 /// # struct TestClient<T>(T); 84 /// # impl<T> TestClient<T> { 85 /// # fn new(channel: T) -> Self { Self(channel) } 86 /// # fn send_compressed(self, _: CompressionEncoding) -> Self { self } 87 /// # } 88 /// 89 /// # async { 90 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 91 /// .connect() 92 /// .await 93 /// .unwrap(); 94 /// 95 /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip); 96 /// # }; 97 /// ``` send_compressed(mut self, encoding: CompressionEncoding) -> Self98 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { 99 self.config.send_compression_encodings = Some(encoding); 100 self 101 } 102 103 /// Enable accepting compressed responses. 104 /// 105 /// Requires the server to also support sending compressed responses. 106 /// 107 /// # Example 108 /// 109 /// The most common way of using this is through a client generated by tonic-build: 110 /// 111 /// ```rust 112 /// use tonic::transport::Channel; 113 /// # enum CompressionEncoding { Gzip } 114 /// # struct TestClient<T>(T); 115 /// # impl<T> TestClient<T> { 116 /// # fn new(channel: T) -> Self { Self(channel) } 117 /// # fn accept_compressed(self, _: CompressionEncoding) -> Self { self } 118 /// # } 119 /// 120 /// # async { 121 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 122 /// .connect() 123 /// .await 124 /// .unwrap(); 125 /// 126 /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip); 127 /// # }; 128 /// ``` accept_compressed(mut self, encoding: CompressionEncoding) -> Self129 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { 130 self.config.accept_compression_encodings.enable(encoding); 131 self 132 } 133 134 /// Limits the maximum size of a decoded message. 135 /// 136 /// # Example 137 /// 138 /// The most common way of using this is through a client generated by tonic-build: 139 /// 140 /// ```rust 141 /// use tonic::transport::Channel; 142 /// # struct TestClient<T>(T); 143 /// # impl<T> TestClient<T> { 144 /// # fn new(channel: T) -> Self { Self(channel) } 145 /// # fn max_decoding_message_size(self, _: usize) -> Self { self } 146 /// # } 147 /// 148 /// # async { 149 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 150 /// .connect() 151 /// .await 152 /// .unwrap(); 153 /// 154 /// // Set the limit to 2MB, Defaults to 4MB. 155 /// let limit = 2 * 1024 * 1024; 156 /// let client = TestClient::new(channel).max_decoding_message_size(limit); 157 /// # }; 158 /// ``` max_decoding_message_size(mut self, limit: usize) -> Self159 pub fn max_decoding_message_size(mut self, limit: usize) -> Self { 160 self.config.max_decoding_message_size = Some(limit); 161 self 162 } 163 164 /// Limits the maximum size of an encoded message. 165 /// 166 /// # Example 167 /// 168 /// The most common way of using this is through a client generated by tonic-build: 169 /// 170 /// ```rust 171 /// use tonic::transport::Channel; 172 /// # struct TestClient<T>(T); 173 /// # impl<T> TestClient<T> { 174 /// # fn new(channel: T) -> Self { Self(channel) } 175 /// # fn max_encoding_message_size(self, _: usize) -> Self { self } 176 /// # } 177 /// 178 /// # async { 179 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 180 /// .connect() 181 /// .await 182 /// .unwrap(); 183 /// 184 /// // Set the limit to 2MB, Defaults to 4MB. 185 /// let limit = 2 * 1024 * 1024; 186 /// let client = TestClient::new(channel).max_encoding_message_size(limit); 187 /// # }; 188 /// ``` max_encoding_message_size(mut self, limit: usize) -> Self189 pub fn max_encoding_message_size(mut self, limit: usize) -> Self { 190 self.config.max_encoding_message_size = Some(limit); 191 self 192 } 193 194 /// Check if the inner [`GrpcService`] is able to accept a new request. 195 /// 196 /// This will call [`GrpcService::poll_ready`] until it returns ready or 197 /// an error. If this returns ready the inner [`GrpcService`] is ready to 198 /// accept one more request. ready(&mut self) -> Result<(), T::Error> where T: GrpcService<Body>,199 pub async fn ready(&mut self) -> Result<(), T::Error> 200 where 201 T: GrpcService<Body>, 202 { 203 future::poll_fn(|cx| self.inner.poll_ready(cx)).await 204 } 205 206 /// Send a single unary gRPC request. unary<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,207 pub async fn unary<M1, M2, C>( 208 &mut self, 209 request: Request<M1>, 210 path: PathAndQuery, 211 codec: C, 212 ) -> Result<Response<M2>, Status> 213 where 214 T: GrpcService<Body>, 215 T::ResponseBody: HttpBody + Send + 'static, 216 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, 217 C: Codec<Encode = M1, Decode = M2>, 218 M1: Send + Sync + 'static, 219 M2: Send + Sync + 'static, 220 { 221 let request = request.map(|m| tokio_stream::once(m)); 222 self.client_streaming(request, path, codec).await 223 } 224 225 /// Send a client side streaming gRPC request. client_streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,226 pub async fn client_streaming<S, M1, M2, C>( 227 &mut self, 228 request: Request<S>, 229 path: PathAndQuery, 230 codec: C, 231 ) -> Result<Response<M2>, Status> 232 where 233 T: GrpcService<Body>, 234 T::ResponseBody: HttpBody + Send + 'static, 235 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, 236 S: Stream<Item = M1> + Send + 'static, 237 C: Codec<Encode = M1, Decode = M2>, 238 M1: Send + Sync + 'static, 239 M2: Send + Sync + 'static, 240 { 241 let (mut parts, body, extensions) = 242 self.streaming(request, path, codec).await?.into_parts(); 243 244 let mut body = pin!(body); 245 246 let message = body 247 .try_next() 248 .await 249 .map_err(|mut status| { 250 status.metadata_mut().merge(parts.clone()); 251 status 252 })? 253 .ok_or_else(|| Status::internal("Missing response message."))?; 254 255 if let Some(trailers) = body.trailers().await? { 256 parts.merge(trailers); 257 } 258 259 Ok(Response::from_parts(parts, message, extensions)) 260 } 261 262 /// Send a server side streaming gRPC request. server_streaming<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,263 pub async fn server_streaming<M1, M2, C>( 264 &mut self, 265 request: Request<M1>, 266 path: PathAndQuery, 267 codec: C, 268 ) -> Result<Response<Streaming<M2>>, Status> 269 where 270 T: GrpcService<Body>, 271 T::ResponseBody: HttpBody + Send + 'static, 272 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, 273 C: Codec<Encode = M1, Decode = M2>, 274 M1: Send + Sync + 'static, 275 M2: Send + Sync + 'static, 276 { 277 let request = request.map(|m| tokio_stream::once(m)); 278 self.streaming(request, path, codec).await 279 } 280 281 /// Send a bi-directional streaming gRPC request. streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, mut codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,282 pub async fn streaming<S, M1, M2, C>( 283 &mut self, 284 request: Request<S>, 285 path: PathAndQuery, 286 mut codec: C, 287 ) -> Result<Response<Streaming<M2>>, Status> 288 where 289 T: GrpcService<Body>, 290 T::ResponseBody: HttpBody + Send + 'static, 291 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, 292 S: Stream<Item = M1> + Send + 'static, 293 C: Codec<Encode = M1, Decode = M2>, 294 M1: Send + Sync + 'static, 295 M2: Send + Sync + 'static, 296 { 297 let request = request 298 .map(|s| { 299 EncodeBody::new_client( 300 codec.encoder(), 301 s.map(Ok), 302 self.config.send_compression_encodings, 303 self.config.max_encoding_message_size, 304 ) 305 }) 306 .map(Body::new); 307 308 let request = self.config.prepare_request(request, path); 309 310 let response = self 311 .inner 312 .call(request) 313 .await 314 .map_err(Status::from_error_generic)?; 315 316 let decoder = codec.decoder(); 317 318 self.create_response(decoder, response) 319 } 320 321 // Keeping this code in a separate function from Self::streaming lets functions that return the 322 // same output share the generated binary code create_response<M2>( &self, decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, response: http::Response<T::ResponseBody>, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,323 fn create_response<M2>( 324 &self, 325 decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, 326 response: http::Response<T::ResponseBody>, 327 ) -> Result<Response<Streaming<M2>>, Status> 328 where 329 T: GrpcService<Body>, 330 T::ResponseBody: HttpBody + Send + 'static, 331 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, 332 { 333 let encoding = CompressionEncoding::from_encoding_header( 334 response.headers(), 335 self.config.accept_compression_encodings, 336 )?; 337 338 let status_code = response.status(); 339 let trailers_only_status = Status::from_header_map(response.headers()); 340 341 // We do not need to check for trailers if the `grpc-status` header is present 342 // with a valid code. 343 let expect_additional_trailers = if let Some(status) = trailers_only_status { 344 if status.code() != Code::Ok { 345 return Err(status); 346 } 347 348 false 349 } else { 350 true 351 }; 352 353 let response = response.map(|body| { 354 if expect_additional_trailers { 355 Streaming::new_response( 356 decoder, 357 body, 358 status_code, 359 encoding, 360 self.config.max_decoding_message_size, 361 ) 362 } else { 363 Streaming::new_empty(decoder, body) 364 } 365 }); 366 367 Ok(Response::from_http(response)) 368 } 369 } 370 371 impl GrpcConfig { prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body>372 fn prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body> { 373 let mut parts = self.origin.clone().into_parts(); 374 375 match &parts.path_and_query { 376 Some(pnq) if pnq != "/" => { 377 parts.path_and_query = Some( 378 format!("{}{}", pnq.path(), path) 379 .parse() 380 .expect("must form valid path_and_query"), 381 ) 382 } 383 _ => { 384 parts.path_and_query = Some(path); 385 } 386 } 387 388 let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); 389 390 let mut request = request.into_http( 391 uri, 392 http::Method::POST, 393 http::Version::HTTP_2, 394 SanitizeHeaders::Yes, 395 ); 396 397 // Add the gRPC related HTTP headers 398 request 399 .headers_mut() 400 .insert(TE, HeaderValue::from_static("trailers")); 401 402 // Set the content type 403 request 404 .headers_mut() 405 .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE); 406 407 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] 408 if let Some(encoding) = self.send_compression_encodings { 409 request.headers_mut().insert( 410 crate::codec::compression::ENCODING_HEADER, 411 encoding.into_header_value(), 412 ); 413 } 414 415 if let Some(header_value) = self 416 .accept_compression_encodings 417 .into_accept_encoding_header_value() 418 { 419 request.headers_mut().insert( 420 crate::codec::compression::ACCEPT_ENCODING_HEADER, 421 header_value, 422 ); 423 } 424 425 request 426 } 427 } 428 429 impl<T: Clone> Clone for Grpc<T> { clone(&self) -> Self430 fn clone(&self) -> Self { 431 Self { 432 inner: self.inner.clone(), 433 config: GrpcConfig { 434 origin: self.config.origin.clone(), 435 send_compression_encodings: self.config.send_compression_encodings, 436 accept_compression_encodings: self.config.accept_compression_encodings, 437 max_encoding_message_size: self.config.max_encoding_message_size, 438 max_decoding_message_size: self.config.max_decoding_message_size, 439 }, 440 } 441 } 442 } 443 444 impl<T: fmt::Debug> fmt::Debug for Grpc<T> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 446 f.debug_struct("Grpc") 447 .field("inner", &self.inner) 448 .field("origin", &self.config.origin) 449 .field( 450 "compression_encoding", 451 &self.config.send_compression_encodings, 452 ) 453 .field( 454 "accept_compression_encodings", 455 &self.config.accept_compression_encodings, 456 ) 457 .field( 458 "max_decoding_message_size", 459 &self.config.max_decoding_message_size, 460 ) 461 .field( 462 "max_encoding_message_size", 463 &self.config.max_encoding_message_size, 464 ) 465 .finish() 466 } 467 } 468