1*694e553bSDave Bakker use anyhow::{Context as _, Result, anyhow};
2*694e553bSDave Bakker use core::future::Future;
3*694e553bSDave Bakker use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses;
4*694e553bSDave Bakker use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket};
5*694e553bSDave Bakker use test_programs::p3::wasi::tls::client::Connector;
6*694e553bSDave Bakker use test_programs::p3::wit_stream;
7*694e553bSDave Bakker 
8*694e553bSDave Bakker struct Component;
9*694e553bSDave Bakker 
10*694e553bSDave Bakker test_programs::p3::export!(Component);
11*694e553bSDave Bakker 
/// TCP port every endpoint is contacted on (the standard HTTPS port).
const PORT: u16 = 443;
13*694e553bSDave Bakker 
test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()>14*694e553bSDave Bakker async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
15*694e553bSDave Bakker     let request = format!(
16*694e553bSDave Bakker         "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"
17*694e553bSDave Bakker     );
18*694e553bSDave Bakker 
19*694e553bSDave Bakker     let sock = TcpSocket::create(ip.family()).unwrap();
20*694e553bSDave Bakker     sock.connect(IpSocketAddress::new(ip, PORT))
21*694e553bSDave Bakker         .await
22*694e553bSDave Bakker         .context("tcp connect failed")?;
23*694e553bSDave Bakker 
24*694e553bSDave Bakker     let conn = Connector::new();
25*694e553bSDave Bakker 
26*694e553bSDave Bakker     let (sock_rx, sock_rx_fut) = sock.receive();
27*694e553bSDave Bakker     let (tls_rx, tls_rx_fut) = conn.receive(sock_rx);
28*694e553bSDave Bakker 
29*694e553bSDave Bakker     let (mut data_tx, data_rx) = wit_stream::new();
30*694e553bSDave Bakker     let (tls_tx, tls_tx_err_fut) = conn.send(data_rx);
31*694e553bSDave Bakker     let sock_tx_fut = sock.send(tls_tx);
32*694e553bSDave Bakker 
33*694e553bSDave Bakker     Connector::connect(conn, domain.into())
34*694e553bSDave Bakker         .await
35*694e553bSDave Bakker         .context("tls handshake failed")?;
36*694e553bSDave Bakker     let buf = data_tx.write_all(request.into()).await;
37*694e553bSDave Bakker     assert!(buf.is_empty());
38*694e553bSDave Bakker 
39*694e553bSDave Bakker     let response = tls_rx.collect().await;
40*694e553bSDave Bakker     let response = String::from_utf8(response)?;
41*694e553bSDave Bakker     if !response.contains("HTTP/1.1 200 OK") {
42*694e553bSDave Bakker         return Err(anyhow!("server did not respond with 200 OK: {response}"));
43*694e553bSDave Bakker     }
44*694e553bSDave Bakker     drop(data_tx);
45*694e553bSDave Bakker     sock_rx_fut.await.context("tcp recv")?;
46*694e553bSDave Bakker     sock_tx_fut.await.context("tcp send")?;
47*694e553bSDave Bakker     tls_rx_fut.await.context("tls recv")?;
48*694e553bSDave Bakker     tls_tx_err_fut.await.context("tls send")?;
49*694e553bSDave Bakker 
50*694e553bSDave Bakker     Ok(())
51*694e553bSDave Bakker }
52*694e553bSDave Bakker 
53*694e553bSDave Bakker /// This test sets up a TCP connection using one domain, and then attempts to
54*694e553bSDave Bakker /// perform a TLS handshake using another unrelated domain. This should result
55*694e553bSDave Bakker /// in a handshake error.
test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()>56*694e553bSDave Bakker async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
57*694e553bSDave Bakker     const BAD_DOMAIN: &str = "wrongdomain.localhost";
58*694e553bSDave Bakker 
59*694e553bSDave Bakker     let sock = TcpSocket::create(ip.family()).unwrap();
60*694e553bSDave Bakker     sock.connect(IpSocketAddress::new(ip, PORT))
61*694e553bSDave Bakker         .await
62*694e553bSDave Bakker         .context("tcp connect failed")?;
63*694e553bSDave Bakker 
64*694e553bSDave Bakker     let (_, data_rx) = wit_stream::new();
65*694e553bSDave Bakker     let conn = Connector::new();
66*694e553bSDave Bakker 
67*694e553bSDave Bakker     conn.receive(sock.receive().0);
68*694e553bSDave Bakker     sock.send(conn.send(data_rx).0);
69*694e553bSDave Bakker 
70*694e553bSDave Bakker     match Connector::connect(conn, BAD_DOMAIN.into()).await {
71*694e553bSDave Bakker         Err(e) => {
72*694e553bSDave Bakker             let debug_string = e.to_debug_string();
73*694e553bSDave Bakker             // We're expecting an error regarding certificates in some form or
74*694e553bSDave Bakker             // another. When we add more TLS backends this naive check will
75*694e553bSDave Bakker             // likely need to be revisited/expanded:
76*694e553bSDave Bakker             if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") {
77*694e553bSDave Bakker                 return Ok(());
78*694e553bSDave Bakker             }
79*694e553bSDave Bakker             Err(anyhow!(debug_string))
80*694e553bSDave Bakker         }
81*694e553bSDave Bakker         Ok(_) => panic!("expecting server name mismatch"),
82*694e553bSDave Bakker     }
83*694e553bSDave Bakker }
84*694e553bSDave Bakker 
try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) where Fut: Future<Output = Result<()>> + 'a,85*694e553bSDave Bakker async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut)
86*694e553bSDave Bakker where
87*694e553bSDave Bakker     Fut: Future<Output = Result<()>> + 'a,
88*694e553bSDave Bakker {
89*694e553bSDave Bakker     // since this is testing remote endpoints to ensure system cert store works
90*694e553bSDave Bakker     // the test uses a couple different endpoints to reduce the number of flakes
91*694e553bSDave Bakker     const DOMAINS: &[&str] = &[
92*694e553bSDave Bakker         "example.com",
93*694e553bSDave Bakker         "api.github.com",
94*694e553bSDave Bakker         "docs.wasmtime.dev",
95*694e553bSDave Bakker         "bytecodealliance.org",
96*694e553bSDave Bakker         "www.rust-lang.org",
97*694e553bSDave Bakker     ];
98*694e553bSDave Bakker 
99*694e553bSDave Bakker     for &domain in DOMAINS {
100*694e553bSDave Bakker         let result = (|| async {
101*694e553bSDave Bakker             let ip = resolve_addresses(domain.into())
102*694e553bSDave Bakker                 .await?
103*694e553bSDave Bakker                 .first()
104*694e553bSDave Bakker                 .map(|a| a.to_owned())
105*694e553bSDave Bakker                 .ok_or_else(|| anyhow!("DNS lookup failed."))?;
106*694e553bSDave Bakker             test(domain, ip).await
107*694e553bSDave Bakker         })();
108*694e553bSDave Bakker 
109*694e553bSDave Bakker         match result.await {
110*694e553bSDave Bakker             Ok(()) => return,
111*694e553bSDave Bakker             Err(e) => {
112*694e553bSDave Bakker                 eprintln!("test for {domain} failed: {e:#}");
113*694e553bSDave Bakker             }
114*694e553bSDave Bakker         }
115*694e553bSDave Bakker     }
116*694e553bSDave Bakker 
117*694e553bSDave Bakker     panic!("all tests failed");
118*694e553bSDave Bakker }
119*694e553bSDave Bakker 
120*694e553bSDave Bakker impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
run() -> Result<(), ()>121*694e553bSDave Bakker     async fn run() -> Result<(), ()> {
122*694e553bSDave Bakker         println!("sample app");
123*694e553bSDave Bakker         try_live_endpoints(test_tls_sample_application).await;
124*694e553bSDave Bakker         println!("invalid cert");
125*694e553bSDave Bakker         try_live_endpoints(test_tls_invalid_certificate).await;
126*694e553bSDave Bakker         Ok(())
127*694e553bSDave Bakker     }
128*694e553bSDave Bakker }
129*694e553bSDave Bakker 
/// Native entry point; does nothing.
fn main() {
    // Intentionally empty: the test logic lives in the `run` export
    // implemented for `Component`.
}
131