1 use std::fmt; 2 use std::sync::Arc; 3 4 use hyper_util::rt::TokioIo; 5 use tokio::io::{AsyncRead, AsyncWrite}; 6 use tokio_rustls::{ 7 rustls::{ 8 crypto, 9 pki_types::{ServerName, TrustAnchor}, 10 ClientConfig, ConfigBuilder, RootCertStore, WantsVerifier, 11 }, 12 TlsConnector as RustlsConnector, 13 }; 14 15 use super::io::BoxedIo; 16 use crate::transport::service::tls::{ 17 convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2, 18 }; 19 use crate::transport::tls::{Certificate, Identity}; 20 21 #[derive(Clone)] 22 pub(crate) struct TlsConnector { 23 config: Arc<ClientConfig>, 24 domain: Arc<ServerName<'static>>, 25 assume_http2: bool, 26 } 27 28 impl TlsConnector { 29 #[allow(clippy::too_many_arguments)] new( ca_certs: Vec<Certificate>, trust_anchors: Vec<TrustAnchor<'static>>, identity: Option<Identity>, domain: &str, assume_http2: bool, use_key_log: bool, #[cfg(feature = "tls-native-roots")] with_native_roots: bool, #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, ) -> Result<Self, crate::BoxError>30 pub(crate) fn new( 31 ca_certs: Vec<Certificate>, 32 trust_anchors: Vec<TrustAnchor<'static>>, 33 identity: Option<Identity>, 34 domain: &str, 35 assume_http2: bool, 36 use_key_log: bool, 37 #[cfg(feature = "tls-native-roots")] with_native_roots: bool, 38 #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool, 39 ) -> Result<Self, crate::BoxError> { 40 fn with_provider( 41 provider: Arc<crypto::CryptoProvider>, 42 ) -> ConfigBuilder<ClientConfig, WantsVerifier> { 43 ClientConfig::builder_with_provider(provider) 44 .with_safe_default_protocol_versions() 45 .unwrap() 46 } 47 48 #[allow(unreachable_patterns)] 49 let builder = match crypto::CryptoProvider::get_default() { 50 Some(provider) => with_provider(provider.clone()), 51 #[cfg(feature = "tls-ring")] 52 None => with_provider(Arc::new(crypto::ring::default_provider())), 53 #[cfg(feature = "tls-aws-lc")] 54 None => with_provider(Arc::new(crypto::aws_lc_rs::default_provider())), 55 // somehow tls is enabled, but neither of the crypto features are enabled. 56 _ => ClientConfig::builder(), 57 }; 58 59 let mut roots = RootCertStore::from_iter(trust_anchors); 60 61 #[cfg(feature = "tls-native-roots")] 62 if with_native_roots { 63 let rustls_native_certs::CertificateResult { certs, errors, .. } = 64 rustls_native_certs::load_native_certs(); 65 if !errors.is_empty() { 66 tracing::debug!("errors occurred when loading native certs: {errors:?}"); 67 } 68 if certs.is_empty() { 69 return Err(TlsError::NativeCertsNotFound.into()); 70 } 71 roots.add_parsable_certificates(certs); 72 } 73 74 #[cfg(feature = "tls-webpki-roots")] 75 if with_webpki_roots { 76 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); 77 } 78 79 for cert in ca_certs { 80 roots.add_parsable_certificates(convert_certificate_to_pki_types(&cert)?); 81 } 82 83 let builder = builder.with_root_certificates(roots); 84 let mut config = match identity { 85 Some(identity) => { 86 let (client_cert, client_key) = convert_identity_to_pki_types(&identity)?; 87 builder.with_client_auth_cert(client_cert, client_key)? 88 } 89 None => builder.with_no_client_auth(), 90 }; 91 92 if use_key_log { 93 config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); 94 } 95 96 config.alpn_protocols.push(ALPN_H2.into()); 97 Ok(Self { 98 config: Arc::new(config), 99 domain: Arc::new(ServerName::try_from(domain)?.to_owned()), 100 assume_http2, 101 }) 102 } 103 connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError> where I: AsyncRead + AsyncWrite + Send + Unpin + 'static,104 pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::BoxError> 105 where 106 I: AsyncRead + AsyncWrite + Send + Unpin + 'static, 107 { 108 let io = RustlsConnector::from(self.config.clone()) 109 .connect(self.domain.as_ref().to_owned(), io) 110 .await?; 111 112 // Generally we require ALPN to be negotiated, but if the user has 113 // explicitly set `assume_http2` to true, we'll allow it to be missing. 114 let (_, session) = io.get_ref(); 115 let alpn_protocol = session.alpn_protocol(); 116 if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { 117 return Err(TlsError::H2NotNegotiated.into()); 118 } 119 Ok(BoxedIo::new(TokioIo::new(io))) 120 } 121 } 122 123 impl fmt::Debug for TlsConnector { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 125 f.debug_struct("TlsConnector").finish() 126 } 127 } 128