xref: /tonic/tonic/src/transport/channel/endpoint.rs (revision 6839a39a)
1 #[cfg(feature = "_tls-any")]
2 use super::service::TlsConnector;
3 use super::service::{self, Executor, SharedExec};
4 use super::Channel;
5 #[cfg(feature = "_tls-any")]
6 use super::ClientTlsConfig;
7 use crate::transport::Error;
8 use bytes::Bytes;
9 use http::{uri::Uri, HeaderValue};
10 use hyper::rt;
11 use hyper_util::client::legacy::connect::HttpConnector;
12 use std::{fmt, future::Future, net::IpAddr, pin::Pin, str::FromStr, time::Duration};
13 use tower_service::Service;
14 
15 /// Channel builder.
16 ///
17 /// This struct is used to build and configure HTTP/2 channels.
18 #[derive(Clone)]
19 pub struct Endpoint {
20     pub(crate) uri: Uri,
21     pub(crate) origin: Option<Uri>,
22     pub(crate) user_agent: Option<HeaderValue>,
23     pub(crate) timeout: Option<Duration>,
24     pub(crate) concurrency_limit: Option<usize>,
25     pub(crate) rate_limit: Option<(u64, Duration)>,
26     #[cfg(feature = "_tls-any")]
27     pub(crate) tls: Option<TlsConnector>,
28     pub(crate) buffer_size: Option<usize>,
29     pub(crate) init_stream_window_size: Option<u32>,
30     pub(crate) init_connection_window_size: Option<u32>,
31     pub(crate) tcp_keepalive: Option<Duration>,
32     pub(crate) tcp_nodelay: bool,
33     pub(crate) http2_keep_alive_interval: Option<Duration>,
34     pub(crate) http2_keep_alive_timeout: Option<Duration>,
35     pub(crate) http2_keep_alive_while_idle: Option<bool>,
36     pub(crate) http2_max_header_list_size: Option<u32>,
37     pub(crate) connect_timeout: Option<Duration>,
38     pub(crate) http2_adaptive_window: Option<bool>,
39     pub(crate) local_address: Option<IpAddr>,
40     pub(crate) executor: SharedExec,
41 }
42 
43 impl Endpoint {
44     // FIXME: determine if we want to expose this or not. This is really
45     // just used in codegen for a shortcut.
46     #[doc(hidden)]
new<D>(dst: D) -> Result<Self, Error> where D: TryInto<Self>, D::Error: Into<crate::BoxError>,47     pub fn new<D>(dst: D) -> Result<Self, Error>
48     where
49         D: TryInto<Self>,
50         D::Error: Into<crate::BoxError>,
51     {
52         let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?;
53         #[cfg(feature = "_tls-any")]
54         if me.uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
55             return me.tls_config(ClientTlsConfig::new().with_enabled_roots());
56         }
57 
58         Ok(me)
59     }
60 
61     /// Convert an `Endpoint` from a static string.
62     ///
63     /// # Panics
64     ///
65     /// This function panics if the argument is an invalid URI.
66     ///
67     /// ```
68     /// # use tonic::transport::Endpoint;
69     /// Endpoint::from_static("https://example.com");
70     /// ```
from_static(s: &'static str) -> Self71     pub fn from_static(s: &'static str) -> Self {
72         let uri = Uri::from_static(s);
73         Self::from(uri)
74     }
75 
76     /// Convert an `Endpoint` from shared bytes.
77     ///
78     /// ```
79     /// # use tonic::transport::Endpoint;
80     /// Endpoint::from_shared("https://example.com".to_string());
81     /// ```
from_shared(s: impl Into<Bytes>) -> Result<Self, Error>82     pub fn from_shared(s: impl Into<Bytes>) -> Result<Self, Error> {
83         let uri = Uri::from_maybe_shared(s.into()).map_err(|e| Error::new_invalid_uri().with(e))?;
84         Ok(Self::from(uri))
85     }
86 
87     /// Set a custom user-agent header.
88     ///
89     /// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
90     /// It must be a value that can be converted into a valid  `http::HeaderValue` or building
91     /// the endpoint will fail.
92     /// ```
93     /// # use tonic::transport::Endpoint;
94     /// # let mut builder = Endpoint::from_static("https://example.com");
95     /// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
96     /// // user-agent: "Greeter tonic/x.x.x"
97     /// ```
user_agent<T>(self, user_agent: T) -> Result<Self, Error> where T: TryInto<HeaderValue>,98     pub fn user_agent<T>(self, user_agent: T) -> Result<Self, Error>
99     where
100         T: TryInto<HeaderValue>,
101     {
102         user_agent
103             .try_into()
104             .map(|ua| Endpoint {
105                 user_agent: Some(ua),
106                 ..self
107             })
108             .map_err(|_| Error::new_invalid_user_agent())
109     }
110 
111     /// Set a custom origin.
112     ///
113     /// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
114     /// which serves multiple services at the same time.
115     /// It will play the role of SNI (Server Name Indication).
116     ///
117     /// ```
118     /// # use tonic::transport::Endpoint;
119     /// # let mut builder = Endpoint::from_static("https://proxy.com");
120     /// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
121     /// // origin: "https://example.com"
122     /// ```
origin(self, origin: Uri) -> Self123     pub fn origin(self, origin: Uri) -> Self {
124         Endpoint {
125             origin: Some(origin),
126             ..self
127         }
128     }
129 
130     /// Apply a timeout to each request.
131     ///
132     /// ```
133     /// # use tonic::transport::Endpoint;
134     /// # use std::time::Duration;
135     /// # let mut builder = Endpoint::from_static("https://example.com");
136     /// builder.timeout(Duration::from_secs(5));
137     /// ```
138     ///
139     /// # Notes
140     ///
141     /// This does **not** set the timeout metadata (`grpc-timeout` header) on
142     /// the request, meaning the server will not be informed of this timeout,
143     /// for that use [`Request::set_timeout`].
144     ///
145     /// [`Request::set_timeout`]: crate::Request::set_timeout
timeout(self, dur: Duration) -> Self146     pub fn timeout(self, dur: Duration) -> Self {
147         Endpoint {
148             timeout: Some(dur),
149             ..self
150         }
151     }
152 
153     /// Apply a timeout to connecting to the uri.
154     ///
155     /// Defaults to no timeout.
156     ///
157     /// ```
158     /// # use tonic::transport::Endpoint;
159     /// # use std::time::Duration;
160     /// # let mut builder = Endpoint::from_static("https://example.com");
161     /// builder.connect_timeout(Duration::from_secs(5));
162     /// ```
connect_timeout(self, dur: Duration) -> Self163     pub fn connect_timeout(self, dur: Duration) -> Self {
164         Endpoint {
165             connect_timeout: Some(dur),
166             ..self
167         }
168     }
169 
170     /// Set whether TCP keepalive messages are enabled on accepted connections.
171     ///
172     /// If `None` is specified, keepalive is disabled, otherwise the duration
173     /// specified will be the time to remain idle before sending TCP keepalive
174     /// probes.
175     ///
176     /// Default is no keepalive (`None`)
177     ///
tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self178     pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
179         Endpoint {
180             tcp_keepalive,
181             ..self
182         }
183     }
184 
185     /// Apply a concurrency limit to each request.
186     ///
187     /// ```
188     /// # use tonic::transport::Endpoint;
189     /// # let mut builder = Endpoint::from_static("https://example.com");
190     /// builder.concurrency_limit(256);
191     /// ```
concurrency_limit(self, limit: usize) -> Self192     pub fn concurrency_limit(self, limit: usize) -> Self {
193         Endpoint {
194             concurrency_limit: Some(limit),
195             ..self
196         }
197     }
198 
199     /// Apply a rate limit to each request.
200     ///
201     /// ```
202     /// # use tonic::transport::Endpoint;
203     /// # use std::time::Duration;
204     /// # let mut builder = Endpoint::from_static("https://example.com");
205     /// builder.rate_limit(32, Duration::from_secs(1));
206     /// ```
rate_limit(self, limit: u64, duration: Duration) -> Self207     pub fn rate_limit(self, limit: u64, duration: Duration) -> Self {
208         Endpoint {
209             rate_limit: Some((limit, duration)),
210             ..self
211         }
212     }
213 
214     /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
215     /// stream-level flow control.
216     ///
217     /// Default is 65,535
218     ///
219     /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize
initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self220     pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
221         Endpoint {
222             init_stream_window_size: sz.into(),
223             ..self
224         }
225     }
226 
227     /// Sets the max connection-level flow control for HTTP2
228     ///
229     /// Default is 65,535
initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self230     pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
231         Endpoint {
232             init_connection_window_size: sz.into(),
233             ..self
234         }
235     }
236 
237     /// Sets the tower service default internal buffer size
238     ///
239     /// Default is 1024
buffer_size(self, sz: impl Into<Option<usize>>) -> Self240     pub fn buffer_size(self, sz: impl Into<Option<usize>>) -> Self {
241         Endpoint {
242             buffer_size: sz.into(),
243             ..self
244         }
245     }
246 
247     /// Configures TLS for the endpoint.
248     #[cfg(feature = "_tls-any")]
tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, Error>249     pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result<Self, Error> {
250         Ok(Endpoint {
251             tls: Some(
252                 tls_config
253                     .into_tls_connector(&self.uri)
254                     .map_err(Error::from_source)?,
255             ),
256             ..self
257         })
258     }
259 
260     /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
tcp_nodelay(self, enabled: bool) -> Self261     pub fn tcp_nodelay(self, enabled: bool) -> Self {
262         Endpoint {
263             tcp_nodelay: enabled,
264             ..self
265         }
266     }
267 
268     /// Set http2 KEEP_ALIVE_INTERVAL. Uses `hyper`'s default otherwise.
http2_keep_alive_interval(self, interval: Duration) -> Self269     pub fn http2_keep_alive_interval(self, interval: Duration) -> Self {
270         Endpoint {
271             http2_keep_alive_interval: Some(interval),
272             ..self
273         }
274     }
275 
276     /// Set http2 KEEP_ALIVE_TIMEOUT. Uses `hyper`'s default otherwise.
keep_alive_timeout(self, duration: Duration) -> Self277     pub fn keep_alive_timeout(self, duration: Duration) -> Self {
278         Endpoint {
279             http2_keep_alive_timeout: Some(duration),
280             ..self
281         }
282     }
283 
284     /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses `hyper`'s default otherwise.
keep_alive_while_idle(self, enabled: bool) -> Self285     pub fn keep_alive_while_idle(self, enabled: bool) -> Self {
286         Endpoint {
287             http2_keep_alive_while_idle: Some(enabled),
288             ..self
289         }
290     }
291 
292     /// Sets whether to use an adaptive flow control. Uses `hyper`'s default otherwise.
http2_adaptive_window(self, enabled: bool) -> Self293     pub fn http2_adaptive_window(self, enabled: bool) -> Self {
294         Endpoint {
295             http2_adaptive_window: Some(enabled),
296             ..self
297         }
298     }
299 
300     /// Sets the max size of received header frames.
301     ///
302     /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB.
http2_max_header_list_size(self, size: u32) -> Self303     pub fn http2_max_header_list_size(self, size: u32) -> Self {
304         Endpoint {
305             http2_max_header_list_size: Some(size),
306             ..self
307         }
308     }
309 
310     /// Sets the executor used to spawn async tasks.
311     ///
312     /// Uses `tokio::spawn` by default.
executor<E>(mut self, executor: E) -> Self where E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,313     pub fn executor<E>(mut self, executor: E) -> Self
314     where
315         E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
316     {
317         self.executor = SharedExec::new(executor);
318         self
319     }
320 
connector<C>(&self, c: C) -> service::Connector<C>321     pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
322         service::Connector::new(
323             c,
324             #[cfg(feature = "_tls-any")]
325             self.tls.clone(),
326         )
327     }
328 
329     /// Set the local address.
330     ///
331     /// This sets the IP address the client will use. By default we let hyper select the IP address.
local_address(self, addr: Option<IpAddr>) -> Self332     pub fn local_address(self, addr: Option<IpAddr>) -> Self {
333         Endpoint {
334             local_address: addr,
335             ..self
336         }
337     }
338 
http_connector(&self) -> service::Connector<HttpConnector>339     pub(crate) fn http_connector(&self) -> service::Connector<HttpConnector> {
340         let mut http = HttpConnector::new();
341         http.enforce_http(false);
342         http.set_nodelay(self.tcp_nodelay);
343         http.set_keepalive(self.tcp_keepalive);
344         http.set_connect_timeout(self.connect_timeout);
345         http.set_local_address(self.local_address);
346         self.connector(http)
347     }
348 
349     /// Create a channel from this config.
connect(&self) -> Result<Channel, Error>350     pub async fn connect(&self) -> Result<Channel, Error> {
351         Channel::connect(self.http_connector(), self.clone()).await
352     }
353 
354     /// Create a channel from this config.
355     ///
356     /// The channel returned by this method does not attempt to connect to the endpoint until first
357     /// use.
connect_lazy(&self) -> Channel358     pub fn connect_lazy(&self) -> Channel {
359         Channel::new(self.http_connector(), self.clone())
360     }
361 
362     /// Connect with a custom connector.
363     ///
364     /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport.
365     /// See the `uds` example for an example on how to use this function to build channel that
366     /// uses a Unix socket transport.
367     ///
368     /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied.
connect_with_connector<C>(&self, connector: C) -> Result<Channel, Error> where C: Service<Uri> + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin, C::Future: Send, crate::BoxError: From<C::Error> + Send,369     pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, Error>
370     where
371         C: Service<Uri> + Send + 'static,
372         C::Response: rt::Read + rt::Write + Send + Unpin,
373         C::Future: Send,
374         crate::BoxError: From<C::Error> + Send,
375     {
376         let connector = self.connector(connector);
377 
378         if let Some(connect_timeout) = self.connect_timeout {
379             let mut connector = hyper_timeout::TimeoutConnector::new(connector);
380             connector.set_connect_timeout(Some(connect_timeout));
381             Channel::connect(connector, self.clone()).await
382         } else {
383             Channel::connect(connector, self.clone()).await
384         }
385     }
386 
387     /// Connect with a custom connector lazily.
388     ///
389     /// This allows you to build a [Channel](struct.Channel.html) that uses a non-HTTP transport
390     /// connect to it lazily.
391     ///
392     /// See the `uds` example for an example on how to use this function to build channel that
393     /// uses a Unix socket transport.
connect_with_connector_lazy<C>(&self, connector: C) -> Channel where C: Service<Uri> + Send + 'static, C::Response: rt::Read + rt::Write + Send + Unpin, C::Future: Send, crate::BoxError: From<C::Error> + Send,394     pub fn connect_with_connector_lazy<C>(&self, connector: C) -> Channel
395     where
396         C: Service<Uri> + Send + 'static,
397         C::Response: rt::Read + rt::Write + Send + Unpin,
398         C::Future: Send,
399         crate::BoxError: From<C::Error> + Send,
400     {
401         let connector = self.connector(connector);
402         if let Some(connect_timeout) = self.connect_timeout {
403             let mut connector = hyper_timeout::TimeoutConnector::new(connector);
404             connector.set_connect_timeout(Some(connect_timeout));
405             Channel::new(connector, self.clone())
406         } else {
407             Channel::new(connector, self.clone())
408         }
409     }
410 
411     /// Get the endpoint uri.
412     ///
413     /// ```
414     /// # use tonic::transport::Endpoint;
415     /// # use http::Uri;
416     /// let endpoint = Endpoint::from_static("https://example.com");
417     ///
418     /// assert_eq!(endpoint.uri(), &Uri::from_static("https://example.com"));
419     /// ```
uri(&self) -> &Uri420     pub fn uri(&self) -> &Uri {
421         &self.uri
422     }
423 
424     /// Get the value of `TCP_NODELAY` option for accepted connections.
get_tcp_nodelay(&self) -> bool425     pub fn get_tcp_nodelay(&self) -> bool {
426         self.tcp_nodelay
427     }
428 
429     /// Get the connect timeout.
get_connect_timeout(&self) -> Option<Duration>430     pub fn get_connect_timeout(&self) -> Option<Duration> {
431         self.connect_timeout
432     }
433 
434     /// Get whether TCP keepalive messages are enabled on accepted connections.
435     ///
436     /// If `None` is specified, keepalive is disabled, otherwise the duration
437     /// specified will be the time to remain idle before sending TCP keepalive
438     /// probes.
get_tcp_keepalive(&self) -> Option<Duration>439     pub fn get_tcp_keepalive(&self) -> Option<Duration> {
440         self.tcp_keepalive
441     }
442 }
443 
444 impl From<Uri> for Endpoint {
from(uri: Uri) -> Self445     fn from(uri: Uri) -> Self {
446         Self {
447             uri,
448             origin: None,
449             user_agent: None,
450             concurrency_limit: None,
451             rate_limit: None,
452             timeout: None,
453             #[cfg(feature = "_tls-any")]
454             tls: None,
455             buffer_size: None,
456             init_stream_window_size: None,
457             init_connection_window_size: None,
458             tcp_keepalive: None,
459             tcp_nodelay: true,
460             http2_keep_alive_interval: None,
461             http2_keep_alive_timeout: None,
462             http2_keep_alive_while_idle: None,
463             http2_max_header_list_size: None,
464             connect_timeout: None,
465             http2_adaptive_window: None,
466             executor: SharedExec::tokio(),
467             local_address: None,
468         }
469     }
470 }
471 
472 impl TryFrom<Bytes> for Endpoint {
473     type Error = Error;
474 
try_from(t: Bytes) -> Result<Self, Self::Error>475     fn try_from(t: Bytes) -> Result<Self, Self::Error> {
476         Self::from_shared(t)
477     }
478 }
479 
480 impl TryFrom<String> for Endpoint {
481     type Error = Error;
482 
try_from(t: String) -> Result<Self, Self::Error>483     fn try_from(t: String) -> Result<Self, Self::Error> {
484         Self::from_shared(t.into_bytes())
485     }
486 }
487 
488 impl TryFrom<&'static str> for Endpoint {
489     type Error = Error;
490 
try_from(t: &'static str) -> Result<Self, Self::Error>491     fn try_from(t: &'static str) -> Result<Self, Self::Error> {
492         Self::from_shared(t.as_bytes())
493     }
494 }
495 
496 impl fmt::Debug for Endpoint {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result497     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
498         f.debug_struct("Endpoint").finish()
499     }
500 }
501 
502 impl FromStr for Endpoint {
503     type Err = Error;
504 
from_str(s: &str) -> Result<Self, Self::Err>505     fn from_str(s: &str) -> Result<Self, Self::Err> {
506         Self::try_from(s.to_string())
507     }
508 }
509