xref: /tonic/tonic/src/server/grpc.rs (revision 79a06cc8)
1 use crate::codec::compression::{
2     CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3 };
4 use crate::codec::EncodeBody;
5 use crate::metadata::GRPC_CONTENT_TYPE;
6 use crate::{
7     body::Body,
8     codec::{Codec, Streaming},
9     server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
10     Request, Status,
11 };
12 use http_body::Body as HttpBody;
13 use std::{fmt, pin::pin};
14 use tokio_stream::{Stream, StreamExt};
15 
// Unwraps a `Result<T, Status>`-like value inside a handler.
//
// On `Ok` the inner value is the macro's result; on `Err` the enclosing
// function returns early with the error converted via `into_http()`.
macro_rules! t {
    ($res:expr) => {
        match $res {
            Err(err_status) => return err_status.into_http(),
            Ok(inner) => inner,
        }
    };
}
24 
25 /// A gRPC Server handler.
26 ///
27 /// This will wrap some inner [`Codec`] and provide utilities to handle
28 /// inbound unary, client side streaming, server side streaming, and
29 /// bi-directional streaming.
30 ///
31 /// Each request handler method accepts some service that implements the
32 /// corresponding service trait and a http request that contains some body that
33 /// implements some [`Body`].
34 pub struct Grpc<T> {
35     codec: T,
36     /// Which compression encodings does the server accept for requests?
37     accept_compression_encodings: EnabledCompressionEncodings,
38     /// Which compression encodings might the server use for responses.
39     send_compression_encodings: EnabledCompressionEncodings,
40     /// Limits the maximum size of a decoded message.
41     max_decoding_message_size: Option<usize>,
42     /// Limits the maximum size of an encoded message.
43     max_encoding_message_size: Option<usize>,
44 }
45 
46 impl<T> Grpc<T>
47 where
48     T: Codec,
49 {
50     /// Creates a new gRPC server with the provided [`Codec`].
new(codec: T) -> Self51     pub fn new(codec: T) -> Self {
52         Self {
53             codec,
54             accept_compression_encodings: EnabledCompressionEncodings::default(),
55             send_compression_encodings: EnabledCompressionEncodings::default(),
56             max_decoding_message_size: None,
57             max_encoding_message_size: None,
58         }
59     }
60 
61     /// Enable accepting compressed requests.
62     ///
63     /// If a request with an unsupported encoding is received the server will respond with
64     /// [`Code::UnUnimplemented`](crate::Code).
65     ///
66     /// # Example
67     ///
68     /// The most common way of using this is through a server generated by tonic-build:
69     ///
70     /// ```rust
71     /// # enum CompressionEncoding { Gzip }
72     /// # struct Svc;
73     /// # struct ExampleServer<T>(T);
74     /// # impl<T> ExampleServer<T> {
75     /// #     fn new(svc: T) -> Self { Self(svc) }
76     /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
77     /// # }
78     /// # #[tonic::async_trait]
79     /// # trait Example {}
80     ///
81     /// #[tonic::async_trait]
82     /// impl Example for Svc {
83     ///     // ...
84     /// }
85     ///
86     /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
87     /// ```
accept_compressed(mut self, encoding: CompressionEncoding) -> Self88     pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
89         self.accept_compression_encodings.enable(encoding);
90         self
91     }
92 
93     /// Enable sending compressed responses.
94     ///
95     /// Requires the client to also support receiving compressed responses.
96     ///
97     /// # Example
98     ///
99     /// The most common way of using this is through a server generated by tonic-build:
100     ///
101     /// ```rust
102     /// # enum CompressionEncoding { Gzip }
103     /// # struct Svc;
104     /// # struct ExampleServer<T>(T);
105     /// # impl<T> ExampleServer<T> {
106     /// #     fn new(svc: T) -> Self { Self(svc) }
107     /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
108     /// # }
109     /// # #[tonic::async_trait]
110     /// # trait Example {}
111     ///
112     /// #[tonic::async_trait]
113     /// impl Example for Svc {
114     ///     // ...
115     /// }
116     ///
117     /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
118     /// ```
send_compressed(mut self, encoding: CompressionEncoding) -> Self119     pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
120         self.send_compression_encodings.enable(encoding);
121         self
122     }
123 
124     /// Limits the maximum size of a decoded message.
125     ///
126     /// # Example
127     ///
128     /// The most common way of using this is through a server generated by tonic-build:
129     ///
130     /// ```rust
131     /// # struct Svc;
132     /// # struct ExampleServer<T>(T);
133     /// # impl<T> ExampleServer<T> {
134     /// #     fn new(svc: T) -> Self { Self(svc) }
135     /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
136     /// # }
137     /// # #[tonic::async_trait]
138     /// # trait Example {}
139     ///
140     /// #[tonic::async_trait]
141     /// impl Example for Svc {
142     ///     // ...
143     /// }
144     ///
145     /// // Set the limit to 2MB, Defaults to 4MB.
146     /// let limit = 2 * 1024 * 1024;
147     /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
148     /// ```
max_decoding_message_size(mut self, limit: usize) -> Self149     pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
150         self.max_decoding_message_size = Some(limit);
151         self
152     }
153 
154     /// Limits the maximum size of a encoded message.
155     ///
156     /// # Example
157     ///
158     /// The most common way of using this is through a server generated by tonic-build:
159     ///
160     /// ```rust
161     /// # struct Svc;
162     /// # struct ExampleServer<T>(T);
163     /// # impl<T> ExampleServer<T> {
164     /// #     fn new(svc: T) -> Self { Self(svc) }
165     /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
166     /// # }
167     /// # #[tonic::async_trait]
168     /// # trait Example {}
169     ///
170     /// #[tonic::async_trait]
171     /// impl Example for Svc {
172     ///     // ...
173     /// }
174     ///
175     /// // Set the limit to 2MB, Defaults to 4MB.
176     /// let limit = 2 * 1024 * 1024;
177     /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
178     /// ```
max_encoding_message_size(mut self, limit: usize) -> Self179     pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
180         self.max_encoding_message_size = Some(limit);
181         self
182     }
183 
184     #[doc(hidden)]
apply_compression_config( mut self, accept_encodings: EnabledCompressionEncodings, send_encodings: EnabledCompressionEncodings, ) -> Self185     pub fn apply_compression_config(
186         mut self,
187         accept_encodings: EnabledCompressionEncodings,
188         send_encodings: EnabledCompressionEncodings,
189     ) -> Self {
190         for &encoding in CompressionEncoding::ENCODINGS {
191             if accept_encodings.is_enabled(encoding) {
192                 self = self.accept_compressed(encoding);
193             }
194             if send_encodings.is_enabled(encoding) {
195                 self = self.send_compressed(encoding);
196             }
197         }
198 
199         self
200     }
201 
202     #[doc(hidden)]
apply_max_message_size_config( mut self, max_decoding_message_size: Option<usize>, max_encoding_message_size: Option<usize>, ) -> Self203     pub fn apply_max_message_size_config(
204         mut self,
205         max_decoding_message_size: Option<usize>,
206         max_encoding_message_size: Option<usize>,
207     ) -> Self {
208         if let Some(limit) = max_decoding_message_size {
209             self = self.max_decoding_message_size(limit);
210         }
211         if let Some(limit) = max_encoding_message_size {
212             self = self.max_encoding_message_size(limit);
213         }
214 
215         self
216     }
217 
218     /// Handle a single unary gRPC request.
unary<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<Body> where S: UnaryService<T::Decode, Response = T::Encode>, B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send,219     pub async fn unary<S, B>(
220         &mut self,
221         mut service: S,
222         req: http::Request<B>,
223     ) -> http::Response<Body>
224     where
225         S: UnaryService<T::Decode, Response = T::Encode>,
226         B: HttpBody + Send + 'static,
227         B::Error: Into<crate::BoxError> + Send,
228     {
229         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
230             req.headers(),
231             self.send_compression_encodings,
232         );
233 
234         let request = match self.map_request_unary(req).await {
235             Ok(r) => r,
236             Err(status) => {
237                 return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
238                     Err(status),
239                     accept_encoding,
240                     SingleMessageCompressionOverride::default(),
241                     self.max_encoding_message_size,
242                 );
243             }
244         };
245 
246         let response = service
247             .call(request)
248             .await
249             .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
250 
251         let compression_override = compression_override_from_response(&response);
252 
253         self.map_response(
254             response,
255             accept_encoding,
256             compression_override,
257             self.max_encoding_message_size,
258         )
259     }
260 
261     /// Handle a server side streaming request.
server_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<Body> where S: ServerStreamingService<T::Decode, Response = T::Encode>, S::ResponseStream: Send + 'static, B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send,262     pub async fn server_streaming<S, B>(
263         &mut self,
264         mut service: S,
265         req: http::Request<B>,
266     ) -> http::Response<Body>
267     where
268         S: ServerStreamingService<T::Decode, Response = T::Encode>,
269         S::ResponseStream: Send + 'static,
270         B: HttpBody + Send + 'static,
271         B::Error: Into<crate::BoxError> + Send,
272     {
273         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
274             req.headers(),
275             self.send_compression_encodings,
276         );
277 
278         let request = match self.map_request_unary(req).await {
279             Ok(r) => r,
280             Err(status) => {
281                 return self.map_response::<S::ResponseStream>(
282                     Err(status),
283                     accept_encoding,
284                     SingleMessageCompressionOverride::default(),
285                     self.max_encoding_message_size,
286                 );
287             }
288         };
289 
290         let response = service.call(request).await;
291 
292         self.map_response(
293             response,
294             accept_encoding,
295             // disabling compression of individual stream items must be done on
296             // the items themselves
297             SingleMessageCompressionOverride::default(),
298             self.max_encoding_message_size,
299         )
300     }
301 
302     /// Handle a client side streaming gRPC request.
client_streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<Body> where S: ClientStreamingService<T::Decode, Response = T::Encode>, B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send + 'static,303     pub async fn client_streaming<S, B>(
304         &mut self,
305         mut service: S,
306         req: http::Request<B>,
307     ) -> http::Response<Body>
308     where
309         S: ClientStreamingService<T::Decode, Response = T::Encode>,
310         B: HttpBody + Send + 'static,
311         B::Error: Into<crate::BoxError> + Send + 'static,
312     {
313         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
314             req.headers(),
315             self.send_compression_encodings,
316         );
317 
318         let request = t!(self.map_request_streaming(req));
319 
320         let response = service
321             .call(request)
322             .await
323             .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
324 
325         let compression_override = compression_override_from_response(&response);
326 
327         self.map_response(
328             response,
329             accept_encoding,
330             compression_override,
331             self.max_encoding_message_size,
332         )
333     }
334 
335     /// Handle a bi-directional streaming gRPC request.
streaming<S, B>( &mut self, mut service: S, req: http::Request<B>, ) -> http::Response<Body> where S: StreamingService<T::Decode, Response = T::Encode> + Send, S::ResponseStream: Send + 'static, B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send,336     pub async fn streaming<S, B>(
337         &mut self,
338         mut service: S,
339         req: http::Request<B>,
340     ) -> http::Response<Body>
341     where
342         S: StreamingService<T::Decode, Response = T::Encode> + Send,
343         S::ResponseStream: Send + 'static,
344         B: HttpBody + Send + 'static,
345         B::Error: Into<crate::BoxError> + Send,
346     {
347         let accept_encoding = CompressionEncoding::from_accept_encoding_header(
348             req.headers(),
349             self.send_compression_encodings,
350         );
351 
352         let request = t!(self.map_request_streaming(req));
353 
354         let response = service.call(request).await;
355 
356         self.map_response(
357             response,
358             accept_encoding,
359             SingleMessageCompressionOverride::default(),
360             self.max_encoding_message_size,
361         )
362     }
363 
map_request_unary<B>( &mut self, request: http::Request<B>, ) -> Result<Request<T::Decode>, Status> where B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send,364     async fn map_request_unary<B>(
365         &mut self,
366         request: http::Request<B>,
367     ) -> Result<Request<T::Decode>, Status>
368     where
369         B: HttpBody + Send + 'static,
370         B::Error: Into<crate::BoxError> + Send,
371     {
372         let request_compression_encoding = self.request_encoding_if_supported(&request)?;
373 
374         let (parts, body) = request.into_parts();
375 
376         let mut stream = pin!(Streaming::new_request(
377             self.codec.decoder(),
378             body,
379             request_compression_encoding,
380             self.max_decoding_message_size,
381         ));
382 
383         let message = stream
384             .try_next()
385             .await?
386             .ok_or_else(|| Status::internal("Missing request message."))?;
387 
388         let mut req = Request::from_http_parts(parts, message);
389 
390         if let Some(trailers) = stream.trailers().await? {
391             req.metadata_mut().merge(trailers);
392         }
393 
394         Ok(req)
395     }
396 
map_request_streaming<B>( &mut self, request: http::Request<B>, ) -> Result<Request<Streaming<T::Decode>>, Status> where B: HttpBody + Send + 'static, B::Error: Into<crate::BoxError> + Send,397     fn map_request_streaming<B>(
398         &mut self,
399         request: http::Request<B>,
400     ) -> Result<Request<Streaming<T::Decode>>, Status>
401     where
402         B: HttpBody + Send + 'static,
403         B::Error: Into<crate::BoxError> + Send,
404     {
405         let encoding = self.request_encoding_if_supported(&request)?;
406 
407         let request = request.map(|body| {
408             Streaming::new_request(
409                 self.codec.decoder(),
410                 body,
411                 encoding,
412                 self.max_decoding_message_size,
413             )
414         });
415 
416         Ok(Request::from_http(request))
417     }
418 
map_response<B>( &mut self, response: Result<crate::Response<B>, Status>, accept_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> http::Response<Body> where B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,419     fn map_response<B>(
420         &mut self,
421         response: Result<crate::Response<B>, Status>,
422         accept_encoding: Option<CompressionEncoding>,
423         compression_override: SingleMessageCompressionOverride,
424         max_message_size: Option<usize>,
425     ) -> http::Response<Body>
426     where
427         B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
428     {
429         let response = t!(response);
430 
431         let (mut parts, body) = response.into_http().into_parts();
432 
433         // Set the content type
434         parts
435             .headers
436             .insert(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
437 
438         #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
439         if let Some(encoding) = accept_encoding {
440             // Set the content encoding
441             parts.headers.insert(
442                 crate::codec::compression::ENCODING_HEADER,
443                 encoding.into_header_value(),
444             );
445         }
446 
447         let body = EncodeBody::new_server(
448             self.codec.encoder(),
449             body,
450             accept_encoding,
451             compression_override,
452             max_message_size,
453         );
454 
455         http::Response::from_parts(parts, Body::new(body))
456     }
457 
request_encoding_if_supported<B>( &self, request: &http::Request<B>, ) -> Result<Option<CompressionEncoding>, Status>458     fn request_encoding_if_supported<B>(
459         &self,
460         request: &http::Request<B>,
461     ) -> Result<Option<CompressionEncoding>, Status> {
462         CompressionEncoding::from_encoding_header(
463             request.headers(),
464             self.accept_compression_encodings,
465         )
466     }
467 }
468 
469 impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result470     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471         f.debug_struct("Grpc")
472             .field("codec", &self.codec)
473             .field(
474                 "accept_compression_encodings",
475                 &self.accept_compression_encodings,
476             )
477             .field(
478                 "send_compression_encodings",
479                 &self.send_compression_encodings,
480             )
481             .finish()
482     }
483 }
484 
compression_override_from_response<B, E>( res: &Result<crate::Response<B>, E>, ) -> SingleMessageCompressionOverride485 fn compression_override_from_response<B, E>(
486     res: &Result<crate::Response<B>, E>,
487 ) -> SingleMessageCompressionOverride {
488     res.as_ref()
489         .ok()
490         .and_then(|response| {
491             response
492                 .extensions()
493                 .get::<SingleMessageCompressionOverride>()
494                 .copied()
495         })
496         .unwrap_or_default()
497 }
498