xref: /tonic/tonic/src/client/grpc.rs (revision 79a06cc8)
1 use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
2 use crate::codec::EncodeBody;
3 use crate::metadata::GRPC_CONTENT_TYPE;
4 use crate::{
5     body::Body,
6     client::GrpcService,
7     codec::{Codec, Decoder, Streaming},
8     request::SanitizeHeaders,
9     Code, Request, Response, Status,
10 };
11 use http::{
12     header::{HeaderValue, CONTENT_TYPE, TE},
13     uri::{PathAndQuery, Uri},
14 };
15 use http_body::Body as HttpBody;
16 use std::{fmt, future, pin::pin};
17 use tokio_stream::{Stream, StreamExt};
18 
19 /// A gRPC client dispatcher.
20 ///
21 /// This will wrap some inner [`GrpcService`] and will encode/decode
22 /// messages via the provided codec.
23 ///
24 /// Each request method takes a [`Request`], a [`PathAndQuery`], and a
25 /// [`Codec`]. The request contains the message to send via the
26 /// [`Codec::encoder`]. The path determines the fully qualified path
27 /// that will be append to the outgoing uri. The path must follow
28 /// the conventions explained in the [gRPC protocol definition] under `Path →`. An
29 /// example of this path could look like `/greeter.Greeter/SayHello`.
30 ///
31 /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
32 pub struct Grpc<T> {
33     inner: T,
34     config: GrpcConfig,
35 }
36 
37 struct GrpcConfig {
38     origin: Uri,
39     /// Which compression encodings does the client accept?
40     accept_compression_encodings: EnabledCompressionEncodings,
41     /// The compression encoding that will be applied to requests.
42     send_compression_encodings: Option<CompressionEncoding>,
43     /// Limits the maximum size of a decoded message.
44     max_decoding_message_size: Option<usize>,
45     /// Limits the maximum size of an encoded message.
46     max_encoding_message_size: Option<usize>,
47 }
48 
49 impl<T> Grpc<T> {
50     /// Creates a new gRPC client with the provided [`GrpcService`].
new(inner: T) -> Self51     pub fn new(inner: T) -> Self {
52         Self::with_origin(inner, Uri::default())
53     }
54 
55     /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`.
56     ///
57     /// The provided Uri will use only the scheme and authority parts as the
58     /// path_and_query portion will be set for each method.
with_origin(inner: T, origin: Uri) -> Self59     pub fn with_origin(inner: T, origin: Uri) -> Self {
60         Self {
61             inner,
62             config: GrpcConfig {
63                 origin,
64                 send_compression_encodings: None,
65                 accept_compression_encodings: EnabledCompressionEncodings::default(),
66                 max_decoding_message_size: None,
67                 max_encoding_message_size: None,
68             },
69         }
70     }
71 
72     /// Compress requests with the provided encoding.
73     ///
74     /// Requires the server to accept the specified encoding, otherwise it might return an error.
75     ///
76     /// # Example
77     ///
78     /// The most common way of using this is through a client generated by tonic-build:
79     ///
80     /// ```rust
81     /// use tonic::transport::Channel;
82     /// # enum CompressionEncoding { Gzip }
83     /// # struct TestClient<T>(T);
84     /// # impl<T> TestClient<T> {
85     /// #     fn new(channel: T) -> Self { Self(channel) }
86     /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
87     /// # }
88     ///
89     /// # async {
90     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
91     ///     .connect()
92     ///     .await
93     ///     .unwrap();
94     ///
95     /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip);
96     /// # };
97     /// ```
send_compressed(mut self, encoding: CompressionEncoding) -> Self98     pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
99         self.config.send_compression_encodings = Some(encoding);
100         self
101     }
102 
103     /// Enable accepting compressed responses.
104     ///
105     /// Requires the server to also support sending compressed responses.
106     ///
107     /// # Example
108     ///
109     /// The most common way of using this is through a client generated by tonic-build:
110     ///
111     /// ```rust
112     /// use tonic::transport::Channel;
113     /// # enum CompressionEncoding { Gzip }
114     /// # struct TestClient<T>(T);
115     /// # impl<T> TestClient<T> {
116     /// #     fn new(channel: T) -> Self { Self(channel) }
117     /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
118     /// # }
119     ///
120     /// # async {
121     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
122     ///     .connect()
123     ///     .await
124     ///     .unwrap();
125     ///
126     /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip);
127     /// # };
128     /// ```
accept_compressed(mut self, encoding: CompressionEncoding) -> Self129     pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
130         self.config.accept_compression_encodings.enable(encoding);
131         self
132     }
133 
134     /// Limits the maximum size of a decoded message.
135     ///
136     /// # Example
137     ///
138     /// The most common way of using this is through a client generated by tonic-build:
139     ///
140     /// ```rust
141     /// use tonic::transport::Channel;
142     /// # struct TestClient<T>(T);
143     /// # impl<T> TestClient<T> {
144     /// #     fn new(channel: T) -> Self { Self(channel) }
145     /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
146     /// # }
147     ///
148     /// # async {
149     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
150     ///     .connect()
151     ///     .await
152     ///     .unwrap();
153     ///
154     /// // Set the limit to 2MB, Defaults to 4MB.
155     /// let limit = 2 * 1024 * 1024;
156     /// let client = TestClient::new(channel).max_decoding_message_size(limit);
157     /// # };
158     /// ```
max_decoding_message_size(mut self, limit: usize) -> Self159     pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
160         self.config.max_decoding_message_size = Some(limit);
161         self
162     }
163 
164     /// Limits the maximum size of an encoded message.
165     ///
166     /// # Example
167     ///
168     /// The most common way of using this is through a client generated by tonic-build:
169     ///
170     /// ```rust
171     /// use tonic::transport::Channel;
172     /// # struct TestClient<T>(T);
173     /// # impl<T> TestClient<T> {
174     /// #     fn new(channel: T) -> Self { Self(channel) }
175     /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
176     /// # }
177     ///
178     /// # async {
179     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
180     ///     .connect()
181     ///     .await
182     ///     .unwrap();
183     ///
184     /// // Set the limit to 2MB, Defaults to 4MB.
185     /// let limit = 2 * 1024 * 1024;
186     /// let client = TestClient::new(channel).max_encoding_message_size(limit);
187     /// # };
188     /// ```
max_encoding_message_size(mut self, limit: usize) -> Self189     pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
190         self.config.max_encoding_message_size = Some(limit);
191         self
192     }
193 
194     /// Check if the inner [`GrpcService`] is able to accept a  new request.
195     ///
196     /// This will call [`GrpcService::poll_ready`] until it returns ready or
197     /// an error. If this returns ready the inner [`GrpcService`] is ready to
198     /// accept one more request.
ready(&mut self) -> Result<(), T::Error> where T: GrpcService<Body>,199     pub async fn ready(&mut self) -> Result<(), T::Error>
200     where
201         T: GrpcService<Body>,
202     {
203         future::poll_fn(|cx| self.inner.poll_ready(cx)).await
204     }
205 
206     /// Send a single unary gRPC request.
unary<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,207     pub async fn unary<M1, M2, C>(
208         &mut self,
209         request: Request<M1>,
210         path: PathAndQuery,
211         codec: C,
212     ) -> Result<Response<M2>, Status>
213     where
214         T: GrpcService<Body>,
215         T::ResponseBody: HttpBody + Send + 'static,
216         <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
217         C: Codec<Encode = M1, Decode = M2>,
218         M1: Send + Sync + 'static,
219         M2: Send + Sync + 'static,
220     {
221         let request = request.map(|m| tokio_stream::once(m));
222         self.client_streaming(request, path, codec).await
223     }
224 
225     /// Send a client side streaming gRPC request.
client_streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,226     pub async fn client_streaming<S, M1, M2, C>(
227         &mut self,
228         request: Request<S>,
229         path: PathAndQuery,
230         codec: C,
231     ) -> Result<Response<M2>, Status>
232     where
233         T: GrpcService<Body>,
234         T::ResponseBody: HttpBody + Send + 'static,
235         <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
236         S: Stream<Item = M1> + Send + 'static,
237         C: Codec<Encode = M1, Decode = M2>,
238         M1: Send + Sync + 'static,
239         M2: Send + Sync + 'static,
240     {
241         let (mut parts, body, extensions) =
242             self.streaming(request, path, codec).await?.into_parts();
243 
244         let mut body = pin!(body);
245 
246         let message = body
247             .try_next()
248             .await
249             .map_err(|mut status| {
250                 status.metadata_mut().merge(parts.clone());
251                 status
252             })?
253             .ok_or_else(|| Status::internal("Missing response message."))?;
254 
255         if let Some(trailers) = body.trailers().await? {
256             parts.merge(trailers);
257         }
258 
259         Ok(Response::from_parts(parts, message, extensions))
260     }
261 
262     /// Send a server side streaming gRPC request.
server_streaming<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,263     pub async fn server_streaming<M1, M2, C>(
264         &mut self,
265         request: Request<M1>,
266         path: PathAndQuery,
267         codec: C,
268     ) -> Result<Response<Streaming<M2>>, Status>
269     where
270         T: GrpcService<Body>,
271         T::ResponseBody: HttpBody + Send + 'static,
272         <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
273         C: Codec<Encode = M1, Decode = M2>,
274         M1: Send + Sync + 'static,
275         M2: Send + Sync + 'static,
276     {
277         let request = request.map(|m| tokio_stream::once(m));
278         self.streaming(request, path, codec).await
279     }
280 
281     /// Send a bi-directional streaming gRPC request.
streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, mut codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,282     pub async fn streaming<S, M1, M2, C>(
283         &mut self,
284         request: Request<S>,
285         path: PathAndQuery,
286         mut codec: C,
287     ) -> Result<Response<Streaming<M2>>, Status>
288     where
289         T: GrpcService<Body>,
290         T::ResponseBody: HttpBody + Send + 'static,
291         <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
292         S: Stream<Item = M1> + Send + 'static,
293         C: Codec<Encode = M1, Decode = M2>,
294         M1: Send + Sync + 'static,
295         M2: Send + Sync + 'static,
296     {
297         let request = request
298             .map(|s| {
299                 EncodeBody::new_client(
300                     codec.encoder(),
301                     s.map(Ok),
302                     self.config.send_compression_encodings,
303                     self.config.max_encoding_message_size,
304                 )
305             })
306             .map(Body::new);
307 
308         let request = self.config.prepare_request(request, path);
309 
310         let response = self
311             .inner
312             .call(request)
313             .await
314             .map_err(Status::from_error_generic)?;
315 
316         let decoder = codec.decoder();
317 
318         self.create_response(decoder, response)
319     }
320 
321     // Keeping this code in a separate function from Self::streaming lets functions that return the
322     // same output share the generated binary code
create_response<M2>( &self, decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, response: http::Response<T::ResponseBody>, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<Body>, T::ResponseBody: HttpBody + Send + 'static, <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,323     fn create_response<M2>(
324         &self,
325         decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
326         response: http::Response<T::ResponseBody>,
327     ) -> Result<Response<Streaming<M2>>, Status>
328     where
329         T: GrpcService<Body>,
330         T::ResponseBody: HttpBody + Send + 'static,
331         <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
332     {
333         let encoding = CompressionEncoding::from_encoding_header(
334             response.headers(),
335             self.config.accept_compression_encodings,
336         )?;
337 
338         let status_code = response.status();
339         let trailers_only_status = Status::from_header_map(response.headers());
340 
341         // We do not need to check for trailers if the `grpc-status` header is present
342         // with a valid code.
343         let expect_additional_trailers = if let Some(status) = trailers_only_status {
344             if status.code() != Code::Ok {
345                 return Err(status);
346             }
347 
348             false
349         } else {
350             true
351         };
352 
353         let response = response.map(|body| {
354             if expect_additional_trailers {
355                 Streaming::new_response(
356                     decoder,
357                     body,
358                     status_code,
359                     encoding,
360                     self.config.max_decoding_message_size,
361                 )
362             } else {
363                 Streaming::new_empty(decoder, body)
364             }
365         });
366 
367         Ok(Response::from_http(response))
368     }
369 }
370 
371 impl GrpcConfig {
prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body>372     fn prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body> {
373         let mut parts = self.origin.clone().into_parts();
374 
375         match &parts.path_and_query {
376             Some(pnq) if pnq != "/" => {
377                 parts.path_and_query = Some(
378                     format!("{}{}", pnq.path(), path)
379                         .parse()
380                         .expect("must form valid path_and_query"),
381                 )
382             }
383             _ => {
384                 parts.path_and_query = Some(path);
385             }
386         }
387 
388         let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
389 
390         let mut request = request.into_http(
391             uri,
392             http::Method::POST,
393             http::Version::HTTP_2,
394             SanitizeHeaders::Yes,
395         );
396 
397         // Add the gRPC related HTTP headers
398         request
399             .headers_mut()
400             .insert(TE, HeaderValue::from_static("trailers"));
401 
402         // Set the content type
403         request
404             .headers_mut()
405             .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);
406 
407         #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
408         if let Some(encoding) = self.send_compression_encodings {
409             request.headers_mut().insert(
410                 crate::codec::compression::ENCODING_HEADER,
411                 encoding.into_header_value(),
412             );
413         }
414 
415         if let Some(header_value) = self
416             .accept_compression_encodings
417             .into_accept_encoding_header_value()
418         {
419             request.headers_mut().insert(
420                 crate::codec::compression::ACCEPT_ENCODING_HEADER,
421                 header_value,
422             );
423         }
424 
425         request
426     }
427 }
428 
429 impl<T: Clone> Clone for Grpc<T> {
clone(&self) -> Self430     fn clone(&self) -> Self {
431         Self {
432             inner: self.inner.clone(),
433             config: GrpcConfig {
434                 origin: self.config.origin.clone(),
435                 send_compression_encodings: self.config.send_compression_encodings,
436                 accept_compression_encodings: self.config.accept_compression_encodings,
437                 max_encoding_message_size: self.config.max_encoding_message_size,
438                 max_decoding_message_size: self.config.max_decoding_message_size,
439             },
440         }
441     }
442 }
443 
444 impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result445     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446         f.debug_struct("Grpc")
447             .field("inner", &self.inner)
448             .field("origin", &self.config.origin)
449             .field(
450                 "compression_encoding",
451                 &self.config.send_compression_encodings,
452             )
453             .field(
454                 "accept_compression_encodings",
455                 &self.config.accept_compression_encodings,
456             )
457             .field(
458                 "max_decoding_message_size",
459                 &self.config.max_decoding_message_size,
460             )
461             .field(
462                 "max_encoding_message_size",
463                 &self.config.max_encoding_message_size,
464             )
465             .finish()
466     }
467 }
468