xref: /tonic/tonic/src/transport/channel/service/tls.rs (revision 77cee9c5)
1 use std::fmt;
2 use std::sync::Arc;
3 
4 use hyper_util::rt::TokioIo;
5 use tokio::io::{AsyncRead, AsyncWrite};
6 use tokio_rustls::{
7     rustls::{
8         crypto,
9         pki_types::{ServerName, TrustAnchor},
10         ClientConfig, ConfigBuilder, RootCertStore, WantsVerifier,
11     },
12     TlsConnector as RustlsConnector,
13 };
14 
15 use super::io::BoxedIo;
16 use crate::transport::service::tls::{
17     convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2,
18 };
19 use crate::transport::tls::{Certificate, Identity};
20 
21 #[derive(Clone)]
22 pub(crate) struct TlsConnector {
23     config: Arc<ClientConfig>,
24     domain: Arc<ServerName<'static>>,
25     assume_http2: bool,
26 }
27 
28 impl TlsConnector {
29     #[allow(clippy::too_many_arguments)]
new( ca_certs: Vec<Certificate>, trust_anchors: Vec<TrustAnchor<'static>>, identity: Option<Identity>, domain: &str, assume_http2: bool, use_key_log: bool, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, ) -> Result<Self, crate::BoxError>30     pub(crate) fn new(
31         ca_certs: Vec<Certificate>,
32         trust_anchors: Vec<TrustAnchor<'static>>,
33         identity: Option<Identity>,
34         domain: &str,
35         assume_http2: bool,
36         use_key_log: bool,
37         #[cfg(feature = "tls-native-roots")] with_native_roots: bool,
38         #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
39     ) -> Result<Self, crate::BoxError> {
40         fn with_provider(
41             provider: Arc<crypto::CryptoProvider>,
42         ) -> ConfigBuilder<ClientConfig, WantsVerifier> {
43             ClientConfig::builder_with_provider(provider)
44                 .with_safe_default_protocol_versions()
45                 .unwrap()
46         }
47 
48         #[allow(unreachable_patterns)]
49         let builder = match crypto::CryptoProvider::get_default() {
50             Some(provider) => with_provider(provider.clone()),
51             #[cfg(feature = "tls-ring")]
52             None => with_provider(Arc::new(crypto::ring::default_provider())),
53             #[cfg(feature = "tls-aws-lc")]
54             None => with_provider(Arc::new(crypto::aws_lc_rs::default_provider())),
55             // somehow tls is enabled, but neither of the crypto features are enabled.
56             _ => ClientConfig::builder(),
57         };
58 
59         let mut roots = RootCertStore::from_iter(trust_anchors);
60 
61         #[cfg(feature = "tls-native-roots")]
62         if with_native_roots {
63             let rustls_native_certs::CertificateResult { certs, errors, .. } =
64                 rustls_native_certs::load_native_certs();
65             if !errors.is_empty() {
66                 tracing::debug!("errors occurred when loading native certs: {errors:?}");
67             }
68             if certs.is_empty() {
69                 return Err(TlsError::NativeCertsNotFound.into());
70             }
71             roots.add_parsable_certificates(certs);
72         }
73 
74         #[cfg(feature = "tls-webpki-roots")]
75         if with_webpki_roots {
76             roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
77         }
78 
79         for cert in ca_certs {
80             roots.add_parsable_certificates(convert_certificate_to_pki_types(&cert)?);
81         }
82 
83         let builder = builder.with_root_certificates(roots);
84         let mut config = match identity {
85             Some(identity) => {
86                 let (client_cert, client_key) = convert_identity_to_pki_types(&identity)?;
87                 builder.with_client_auth_cert(client_cert, client_key)?
88             }
89             None => builder.with_no_client_auth(),
90         };
91 
92         if use_key_log {
93             config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
94         }
95 
96         config.alpn_protocols.push(ALPN_H2.into());
97         Ok(Self {
98             config: Arc::new(config),
99             domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
100             assume_http2,
101         })
102     }
103 
connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError> where I: AsyncRead + AsyncWrite + Send + Unpin + 'static,104     pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError>
105     where
106         I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
107     {
108         let io = RustlsConnector::from(self.config.clone())
109             .connect(self.domain.as_ref().to_owned(), io)
110             .await?;
111 
112         // Generally we require ALPN to be negotiated, but if the user has
113         // explicitly set `assume_http2` to true, we'll allow it to be missing.
114         let (_, session) = io.get_ref();
115         let alpn_protocol = session.alpn_protocol();
116         if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
117             return Err(TlsError::H2NotNegotiated.into());
118         }
119         Ok(BoxedIo::new(TokioIo::new(io)))
120     }
121 }
122 
123 impl fmt::Debug for TlsConnector {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result124     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125         f.debug_struct("TlsConnector").finish()
126     }
127 }
128