xref: /tonic/tonic/src/request.rs (revision f4a879db)
1 use crate::metadata::{MetadataMap, MetadataValue};
2 #[cfg(feature = "server")]
3 use crate::transport::server::TcpConnectInfo;
4 #[cfg(all(feature = "server", feature = "_tls-any"))]
5 use crate::transport::server::TlsConnectInfo;
6 use http::Extensions;
7 #[cfg(feature = "server")]
8 use std::net::SocketAddr;
9 #[cfg(all(feature = "server", feature = "_tls-any"))]
10 use std::sync::Arc;
11 use std::time::Duration;
12 #[cfg(all(feature = "server", feature = "_tls-any"))]
13 use tokio_rustls::rustls::pki_types::CertificateDer;
14 use tokio_stream::Stream;
15 
16 /// A gRPC request and metadata from an RPC call.
17 #[derive(Debug)]
18 pub struct Request<T> {
19     metadata: MetadataMap,
20     message: T,
21     extensions: Extensions,
22 }
23 
24 /// Trait implemented by RPC request types.
25 ///
26 /// Types implementing this trait can be used as arguments to client RPC
27 /// methods without explicitly wrapping them into `tonic::Request`s. The purpose
28 /// is to make client calls slightly more convenient to write.
29 ///
30 /// Tonic's code generation and blanket implementations handle this for you,
31 /// so it is not necessary to implement this trait directly.
32 ///
33 /// # Example
34 ///
35 /// Given the following gRPC method definition:
36 /// ```proto
37 /// rpc GetFeature(Point) returns (Feature) {}
38 /// ```
39 ///
40 /// we can call `get_feature` in two equivalent ways:
41 /// ```rust
42 /// # pub struct Point {}
43 /// # pub struct Client {}
44 /// # impl Client {
45 /// #   fn get_feature(&self, r: impl tonic::IntoRequest<Point>) {}
46 /// # }
47 /// # let client = Client {};
48 /// use tonic::Request;
49 ///
50 /// client.get_feature(Point {});
51 /// client.get_feature(Request::new(Point {}));
52 /// ```
53 pub trait IntoRequest<T>: sealed::Sealed {
54     /// Wrap the input message `T` in a `tonic::Request`
into_request(self) -> Request<T>55     fn into_request(self) -> Request<T>;
56 }
57 
58 /// Trait implemented by RPC streaming request types.
59 ///
60 /// Types implementing this trait can be used as arguments to client streaming
61 /// RPC methods without explicitly wrapping them into `tonic::Request`s. The
62 /// purpose is to make client calls slightly more convenient to write.
63 ///
64 /// Tonic's code generation and blanket implementations handle this for you,
65 /// so it is not necessary to implement this trait directly.
66 ///
67 /// # Example
68 ///
69 /// Given the following gRPC service method definition:
70 /// ```proto
71 /// rpc RecordRoute(stream Point) returns (RouteSummary) {}
72 /// ```
73 /// we can call `record_route` in two equivalent ways:
74 ///
75 /// ```rust
76 /// # #[derive(Clone)]
77 /// # pub struct Point {};
78 /// # pub struct Client {};
79 /// # impl Client {
80 /// #   fn record_route(&self, r: impl tonic::IntoStreamingRequest<Message = Point>) {}
81 /// # }
82 /// # let client = Client {};
83 /// use tonic::Request;
84 ///
85 /// let messages = vec![Point {}, Point {}];
86 ///
87 /// client.record_route(Request::new(tokio_stream::iter(messages.clone())));
88 /// client.record_route(tokio_stream::iter(messages));
89 /// ```
90 pub trait IntoStreamingRequest: sealed::Sealed {
91     /// The RPC request stream type
92     type Stream: Stream<Item = Self::Message> + Send + 'static;
93 
94     /// The RPC request type
95     type Message;
96 
97     /// Wrap the stream of messages in a `tonic::Request`
into_streaming_request(self) -> Request<Self::Stream>98     fn into_streaming_request(self) -> Request<Self::Stream>;
99 }
100 
101 impl<T> Request<T> {
102     /// Create a new gRPC request.
103     ///
104     /// ```rust
105     /// # use tonic::Request;
106     /// # pub struct HelloRequest {
107     /// #   pub name: String,
108     /// # }
109     /// Request::new(HelloRequest {
110     ///    name: "Bob".into(),
111     /// });
112     /// ```
new(message: T) -> Self113     pub fn new(message: T) -> Self {
114         Request {
115             metadata: MetadataMap::new(),
116             message,
117             extensions: Extensions::new(),
118         }
119     }
120 
121     /// Get a reference to the message
get_ref(&self) -> &T122     pub fn get_ref(&self) -> &T {
123         &self.message
124     }
125 
126     /// Get a mutable reference to the message
get_mut(&mut self) -> &mut T127     pub fn get_mut(&mut self) -> &mut T {
128         &mut self.message
129     }
130 
131     /// Get a reference to the custom request metadata.
metadata(&self) -> &MetadataMap132     pub fn metadata(&self) -> &MetadataMap {
133         &self.metadata
134     }
135 
136     /// Get a mutable reference to the request metadata.
metadata_mut(&mut self) -> &mut MetadataMap137     pub fn metadata_mut(&mut self) -> &mut MetadataMap {
138         &mut self.metadata
139     }
140 
141     /// Consumes `self`, returning the message
into_inner(self) -> T142     pub fn into_inner(self) -> T {
143         self.message
144     }
145 
146     /// Consumes `self` returning the parts of the request.
into_parts(self) -> (MetadataMap, Extensions, T)147     pub fn into_parts(self) -> (MetadataMap, Extensions, T) {
148         (self.metadata, self.extensions, self.message)
149     }
150 
151     /// Create a new gRPC request from metadata, extensions and message.
from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self152     pub fn from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self {
153         Self {
154             metadata,
155             extensions,
156             message,
157         }
158     }
159 
from_http_parts(parts: http::request::Parts, message: T) -> Self160     pub(crate) fn from_http_parts(parts: http::request::Parts, message: T) -> Self {
161         Request {
162             metadata: MetadataMap::from_headers(parts.headers),
163             message,
164             extensions: parts.extensions,
165         }
166     }
167 
168     /// Convert an HTTP request to a gRPC request
from_http(http: http::Request<T>) -> Self169     pub fn from_http(http: http::Request<T>) -> Self {
170         let (parts, message) = http.into_parts();
171         Request::from_http_parts(parts, message)
172     }
173 
into_http( self, uri: http::Uri, method: http::Method, version: http::Version, sanitize_headers: SanitizeHeaders, ) -> http::Request<T>174     pub(crate) fn into_http(
175         self,
176         uri: http::Uri,
177         method: http::Method,
178         version: http::Version,
179         sanitize_headers: SanitizeHeaders,
180     ) -> http::Request<T> {
181         let mut request = http::Request::new(self.message);
182 
183         *request.version_mut() = version;
184         *request.method_mut() = method;
185         *request.uri_mut() = uri;
186         *request.headers_mut() = match sanitize_headers {
187             SanitizeHeaders::Yes => self.metadata.into_sanitized_headers(),
188             SanitizeHeaders::No => self.metadata.into_headers(),
189         };
190         *request.extensions_mut() = self.extensions;
191 
192         request
193     }
194 
195     #[doc(hidden)]
map<F, U>(self, f: F) -> Request<U> where F: FnOnce(T) -> U,196     pub fn map<F, U>(self, f: F) -> Request<U>
197     where
198         F: FnOnce(T) -> U,
199     {
200         let message = f(self.message);
201 
202         Request {
203             metadata: self.metadata,
204             message,
205             extensions: self.extensions,
206         }
207     }
208 
209     /// Get the local address of this connection.
210     ///
211     /// This will return `None` if the `IO` type used
212     /// does not implement `Connected` or when using a unix domain socket.
213     /// This currently only works on the server side.
214     #[cfg(feature = "server")]
local_addr(&self) -> Option<SocketAddr>215     pub fn local_addr(&self) -> Option<SocketAddr> {
216         let addr = self
217             .extensions()
218             .get::<TcpConnectInfo>()
219             .and_then(|i| i.local_addr());
220 
221         #[cfg(feature = "_tls-any")]
222         let addr = addr.or_else(|| {
223             self.extensions()
224                 .get::<TlsConnectInfo<TcpConnectInfo>>()
225                 .and_then(|i| i.get_ref().local_addr())
226         });
227 
228         addr
229     }
230 
231     /// Get the remote address of this connection.
232     ///
233     /// This will return `None` if the `IO` type used
234     /// does not implement `Connected` or when using a unix domain socket.
235     /// This currently only works on the server side.
236     #[cfg(feature = "server")]
remote_addr(&self) -> Option<SocketAddr>237     pub fn remote_addr(&self) -> Option<SocketAddr> {
238         let addr = self
239             .extensions()
240             .get::<TcpConnectInfo>()
241             .and_then(|i| i.remote_addr());
242 
243         #[cfg(feature = "_tls-any")]
244         let addr = addr.or_else(|| {
245             self.extensions()
246                 .get::<TlsConnectInfo<TcpConnectInfo>>()
247                 .and_then(|i| i.get_ref().remote_addr())
248         });
249 
250         addr
251     }
252 
253     /// Get the peer certificates of the connected client.
254     ///
255     /// This is used to fetch the certificates from the TLS session
256     /// and is mostly used for mTLS. This currently only returns
257     /// `Some` on the server side of the `transport` server with
258     /// TLS enabled connections.
259     #[cfg(all(feature = "server", feature = "_tls-any"))]
peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>>260     pub fn peer_certs(&self) -> Option<Arc<Vec<CertificateDer<'static>>>> {
261         self.extensions()
262             .get::<TlsConnectInfo<TcpConnectInfo>>()
263             .and_then(|i| i.peer_certs())
264     }
265 
266     /// Set the max duration the request is allowed to take.
267     ///
268     /// Requires the server to support the `grpc-timeout` metadata, which Tonic does.
269     ///
270     /// The duration will be formatted according to [the spec] and use the most precise unit
271     /// possible.
272     ///
273     /// Example:
274     ///
275     /// ```rust
276     /// use std::time::Duration;
277     /// use tonic::Request;
278     ///
279     /// let mut request = Request::new(());
280     ///
281     /// request.set_timeout(Duration::from_secs(30));
282     ///
283     /// let value = request.metadata().get("grpc-timeout").unwrap();
284     ///
285     /// assert_eq!(
286     ///     value,
287     ///     // equivalent to 30 seconds
288     ///     "30000000u"
289     /// );
290     /// ```
291     ///
292     /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
set_timeout(&mut self, deadline: Duration)293     pub fn set_timeout(&mut self, deadline: Duration) {
294         let value: MetadataValue<_> = duration_to_grpc_timeout(deadline).parse().unwrap();
295         self.metadata_mut()
296             .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value);
297     }
298 
299     /// Returns a reference to the associated extensions.
extensions(&self) -> &Extensions300     pub fn extensions(&self) -> &Extensions {
301         &self.extensions
302     }
303 
304     /// Returns a mutable reference to the associated extensions.
305     ///
306     /// # Example
307     ///
308     /// Extensions can be set in interceptors:
309     ///
310     /// ```no_run
311     /// use tonic::{Request, Status};
312     ///
313     /// #[derive(Clone)] // Extensions must be Clone
314     /// struct MyExtension {
315     ///     some_piece_of_data: String,
316     /// }
317     ///
318     /// fn intercept(mut request: Request<()>) -> Result<Request<()>, Status> {
319     ///     request.extensions_mut().insert(MyExtension {
320     ///         some_piece_of_data: "foo".to_string(),
321     ///     });
322     ///
323     ///     Ok(request)
324     /// }
325     /// ```
326     ///
327     /// And picked up by RPCs:
328     ///
329     /// ```no_run
330     /// use tonic::{async_trait, Status, Request, Response};
331     /// #
332     /// # struct Output {}
333     /// # struct Input;
334     /// # struct MyService;
335     /// # struct MyExtension;
336     /// # #[async_trait]
337     /// # trait TestService {
338     /// #     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status>;
339     /// # }
340     ///
341     /// #[async_trait]
342     /// impl TestService for MyService {
343     ///     async fn handler(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
344     ///         let value: &MyExtension = req.extensions().get::<MyExtension>().unwrap();
345     ///
346     ///         Ok(Response::new(Output {}))
347     ///     }
348     /// }
349     /// ```
extensions_mut(&mut self) -> &mut Extensions350     pub fn extensions_mut(&mut self) -> &mut Extensions {
351         &mut self.extensions
352     }
353 }
354 
355 impl<T> IntoRequest<T> for T {
into_request(self) -> Request<Self>356     fn into_request(self) -> Request<Self> {
357         Request::new(self)
358     }
359 }
360 
361 impl<T> IntoRequest<T> for Request<T> {
into_request(self) -> Request<T>362     fn into_request(self) -> Request<T> {
363         self
364     }
365 }
366 
367 impl<T> IntoStreamingRequest for T
368 where
369     T: Stream + Send + 'static,
370 {
371     type Stream = T;
372     type Message = T::Item;
373 
into_streaming_request(self) -> Request<Self>374     fn into_streaming_request(self) -> Request<Self> {
375         Request::new(self)
376     }
377 }
378 
379 impl<T> IntoStreamingRequest for Request<T>
380 where
381     T: Stream + Send + 'static,
382 {
383     type Stream = T;
384     type Message = T::Item;
385 
into_streaming_request(self) -> Self386     fn into_streaming_request(self) -> Self {
387         self
388     }
389 }
390 
mod sealed {
    /// Supertrait required by `IntoRequest` and `IntoStreamingRequest`.
    ///
    /// It lives in a private module, so it cannot be named from outside the
    /// parent module.
    pub trait Sealed {}
}

// Blanket implementation: every sized type satisfies the `Sealed` bound.
// NOTE(review): because of this blanket impl, the bound alone does not prevent
// other crates from implementing the public traits for their own types.
impl<U> sealed::Sealed for U {}
396 
/// Format `duration` as a gRPC `grpc-timeout` value.
///
/// The result is at most 8 decimal digits followed by a unit character
/// (`n`, `u`, `m`, `S`, `M` or `H`). The most precise unit whose value still
/// fits in 8 digits is chosen; coarser units truncate toward zero.
///
/// Durations too large to express even in hours (more than 99,999,999 hours)
/// are clamped to `"99999999H"`, the largest encodable timeout. This avoids
/// panicking on inputs such as `Duration::MAX`.
fn duration_to_grpc_timeout(duration: Duration) -> String {
    // The gRPC spec specifies that the timeout must be at most 8 digits, so this
    // is the largest a value can be before a coarser unit is needed.
    const MAX_VALUE: u128 = 99_999_999;

    let secs = u128::from(duration.as_secs());

    // Ordered from most to least precise; the first candidate that fits wins.
    let candidates: [(u128, char); 6] = [
        (duration.as_nanos(), 'n'),
        (duration.as_micros(), 'u'),
        (duration.as_millis(), 'm'),
        (secs, 'S'),
        (secs / 60, 'M'),
        (secs / 3600, 'H'),
    ];

    candidates
        .iter()
        .find(|(value, _)| *value <= MAX_VALUE)
        .map(|(value, unit)| format!("{}{}", value, unit))
        // Only reachable for durations above ~11,400 years; clamp to the
        // maximum representable timeout rather than panicking.
        .unwrap_or_else(|| format!("{}H", MAX_VALUE))
}
430 
/// Controls whether reserved gRPC headers are stripped from the metadata when a
/// `tonic::Request` is converted into an `http::Request`.
pub(crate) enum SanitizeHeaders {
    /// Remove the reserved gRPC headers before building the header map.
    Yes,
    /// Do not strip reserved headers; use the metadata headers as given.
    No,
}
437 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{MetadataKey, MetadataValue};

    use http::Uri;

    /// Convert `request` into an HTTP/2 POST with header sanitization enabled.
    fn into_sanitized_http<T>(request: Request<T>) -> http::Request<T> {
        request.into_http(
            Uri::default(),
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        )
    }

    #[test]
    fn reserved_headers_are_excluded() {
        let mut r = Request::new(1);

        for header in &MetadataMap::GRPC_RESERVED_HEADERS {
            r.metadata_mut().insert(
                MetadataKey::unchecked_from_header_name(header.clone()),
                MetadataValue::from_static("invalid"),
            );
        }

        let http_request = into_sanitized_http(r);
        assert!(http_request.headers().is_empty());
    }

    #[test]
    fn sanitizing_keeps_custom_metadata() {
        let mut r = Request::new(1);
        r.metadata_mut()
            .insert("x-custom", MetadataValue::from_static("value"));

        let http_request = into_sanitized_http(r);
        assert_eq!(http_request.headers().get("x-custom").unwrap(), "value");
    }

    // Expected values are written out literally so the tests do not share
    // arithmetic with the code under test.

    #[test]
    fn duration_to_grpc_timeout_less_than_second() {
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_millis(500)),
            "500000u"
        );
    }

    #[test]
    fn duration_to_grpc_timeout_more_than_second() {
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_secs(30)),
            "30000000u"
        );
    }

    #[test]
    fn duration_to_grpc_timeout_one_hour() {
        assert_eq!(
            duration_to_grpc_timeout(Duration::from_secs(60 * 60)),
            "3600000m"
        );
    }

    #[test]
    fn duration_to_grpc_timeout_uses_each_unit() {
        let cases = [
            (Duration::from_nanos(0), "0n"),
            (Duration::from_nanos(99_999_999), "99999999n"),
            (Duration::from_nanos(100_000_000), "100000u"),
            (Duration::from_secs(100_000), "100000S"),
            (Duration::from_secs(200_000_000), "3333333M"),
            (Duration::from_secs(6_000_000_000), "1666666H"),
        ];

        for &(duration, expected) in &cases {
            assert_eq!(
                duration_to_grpc_timeout(duration),
                expected,
                "{:?}",
                duration
            );
        }
    }
}
486