1*694e553bSDave Bakker use anyhow::{Context, Result, anyhow};
2*694e553bSDave Bakker use core::str;
3*694e553bSDave Bakker use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network};
4*694e553bSDave Bakker use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
5*694e553bSDave Bakker use test_programs::wasi::tls::types::ClientHandshake;
6*694e553bSDave Bakker 
7*694e553bSDave Bakker const PORT: u16 = 443;
8*694e553bSDave Bakker 
test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()>9*694e553bSDave Bakker fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
10*694e553bSDave Bakker     let request = format!(
11*694e553bSDave Bakker         "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"
12*694e553bSDave Bakker     );
13*694e553bSDave Bakker 
14*694e553bSDave Bakker     let net = Network::default();
15*694e553bSDave Bakker 
16*694e553bSDave Bakker     let socket = TcpSocket::new(ip.family()).unwrap();
17*694e553bSDave Bakker     let (tcp_input, tcp_output) = socket
18*694e553bSDave Bakker         .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
19*694e553bSDave Bakker         .context("tcp connect failed")?;
20*694e553bSDave Bakker 
21*694e553bSDave Bakker     let (client_connection, tls_input, tls_output) =
22*694e553bSDave Bakker         ClientHandshake::new(domain, tcp_input, tcp_output)
23*694e553bSDave Bakker             .blocking_finish()
24*694e553bSDave Bakker             .context("tls handshake failed")?;
25*694e553bSDave Bakker 
26*694e553bSDave Bakker     tls_output
27*694e553bSDave Bakker         .blocking_write_util(request.as_bytes())
28*694e553bSDave Bakker         .context("writing http request failed")?;
29*694e553bSDave Bakker     let response = tls_input
30*694e553bSDave Bakker         .blocking_read_to_end()
31*694e553bSDave Bakker         .context("reading http response failed")?;
32*694e553bSDave Bakker     client_connection
33*694e553bSDave Bakker         .blocking_close_output(&tls_output)
34*694e553bSDave Bakker         .context("closing tls connection failed")?;
35*694e553bSDave Bakker     socket.shutdown(ShutdownType::Both)?;
36*694e553bSDave Bakker 
37*694e553bSDave Bakker     let response = String::from_utf8(response)?;
38*694e553bSDave Bakker     if response.contains("HTTP/1.1 200 OK") {
39*694e553bSDave Bakker         Ok(())
40*694e553bSDave Bakker     } else {
41*694e553bSDave Bakker         Err(anyhow!("server did not respond with 200 OK: {response}"))
42*694e553bSDave Bakker     }
43*694e553bSDave Bakker }
44*694e553bSDave Bakker 
45*694e553bSDave Bakker /// This test sets up a TCP connection using one domain, and then attempts to
46*694e553bSDave Bakker /// perform a TLS handshake using another unrelated domain. This should result
47*694e553bSDave Bakker /// in a handshake error.
test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()>48*694e553bSDave Bakker fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
49*694e553bSDave Bakker     const BAD_DOMAIN: &'static str = "wrongdomain.localhost";
50*694e553bSDave Bakker 
51*694e553bSDave Bakker     let net = Network::default();
52*694e553bSDave Bakker 
53*694e553bSDave Bakker     let socket = TcpSocket::new(ip.family()).unwrap();
54*694e553bSDave Bakker     let (tcp_input, tcp_output) = socket
55*694e553bSDave Bakker         .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
56*694e553bSDave Bakker         .context("tcp connect failed")?;
57*694e553bSDave Bakker 
58*694e553bSDave Bakker     match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() {
59*694e553bSDave Bakker         Err(e) => {
60*694e553bSDave Bakker             let debug_string = e.to_debug_string();
61*694e553bSDave Bakker             // We're expecting an error regarding certificates in some form or
62*694e553bSDave Bakker             // another. When we add more TLS backends this naive check will
63*694e553bSDave Bakker             // likely need to be revisited/expanded:
64*694e553bSDave Bakker             if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") {
65*694e553bSDave Bakker                 return Ok(());
66*694e553bSDave Bakker             }
67*694e553bSDave Bakker             Err(e.into())
68*694e553bSDave Bakker         }
69*694e553bSDave Bakker         Ok(_) => panic!("expecting server name mismatch"),
70*694e553bSDave Bakker     }
71*694e553bSDave Bakker }
72*694e553bSDave Bakker 
try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>)73*694e553bSDave Bakker fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) {
74*694e553bSDave Bakker     // since this is testing remote endpoints to ensure system cert store works
75*694e553bSDave Bakker     // the test uses a couple different endpoints to reduce the number of flakes
76*694e553bSDave Bakker     const DOMAINS: &'static [&'static str] = &[
77*694e553bSDave Bakker         "example.com",
78*694e553bSDave Bakker         "api.github.com",
79*694e553bSDave Bakker         "docs.wasmtime.dev",
80*694e553bSDave Bakker         "bytecodealliance.org",
81*694e553bSDave Bakker         "www.rust-lang.org",
82*694e553bSDave Bakker     ];
83*694e553bSDave Bakker 
84*694e553bSDave Bakker     let net = Network::default();
85*694e553bSDave Bakker 
86*694e553bSDave Bakker     for &domain in DOMAINS {
87*694e553bSDave Bakker         let result = (|| {
88*694e553bSDave Bakker             let ip = net
89*694e553bSDave Bakker                 .permissive_blocking_resolve_addresses(domain)?
90*694e553bSDave Bakker                 .first()
91*694e553bSDave Bakker                 .map(|a| a.to_owned())
92*694e553bSDave Bakker                 .ok_or_else(|| anyhow!("DNS lookup failed."))?;
93*694e553bSDave Bakker             test(&domain, ip)
94*694e553bSDave Bakker         })();
95*694e553bSDave Bakker 
96*694e553bSDave Bakker         match result {
97*694e553bSDave Bakker             Ok(()) => return,
98*694e553bSDave Bakker             Err(e) => {
99*694e553bSDave Bakker                 eprintln!("test for {domain} failed: {e:#}");
100*694e553bSDave Bakker             }
101*694e553bSDave Bakker         }
102*694e553bSDave Bakker     }
103*694e553bSDave Bakker 
104*694e553bSDave Bakker     panic!("all tests failed");
105*694e553bSDave Bakker }
106*694e553bSDave Bakker 
main()107*694e553bSDave Bakker fn main() {
108*694e553bSDave Bakker     println!("sample app");
109*694e553bSDave Bakker     try_live_endpoints(test_tls_sample_application);
110*694e553bSDave Bakker     println!("invalid cert");
111*694e553bSDave Bakker     try_live_endpoints(test_tls_invalid_certificate);
112*694e553bSDave Bakker }
113