1*694e553bSDave Bakker use anyhow::{Context, Result, anyhow};
2*694e553bSDave Bakker use core::str;
3*694e553bSDave Bakker use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network};
4*694e553bSDave Bakker use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
5*694e553bSDave Bakker use test_programs::wasi::tls::types::ClientHandshake;
6*694e553bSDave Bakker
/// TCP port every test connects to: the standard HTTPS port, since each test
/// performs a TLS handshake (and, for the sample app, an HTTP request) on it.
const PORT: u16 = 443;
8*694e553bSDave Bakker
test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()>9*694e553bSDave Bakker fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
10*694e553bSDave Bakker let request = format!(
11*694e553bSDave Bakker "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"
12*694e553bSDave Bakker );
13*694e553bSDave Bakker
14*694e553bSDave Bakker let net = Network::default();
15*694e553bSDave Bakker
16*694e553bSDave Bakker let socket = TcpSocket::new(ip.family()).unwrap();
17*694e553bSDave Bakker let (tcp_input, tcp_output) = socket
18*694e553bSDave Bakker .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
19*694e553bSDave Bakker .context("tcp connect failed")?;
20*694e553bSDave Bakker
21*694e553bSDave Bakker let (client_connection, tls_input, tls_output) =
22*694e553bSDave Bakker ClientHandshake::new(domain, tcp_input, tcp_output)
23*694e553bSDave Bakker .blocking_finish()
24*694e553bSDave Bakker .context("tls handshake failed")?;
25*694e553bSDave Bakker
26*694e553bSDave Bakker tls_output
27*694e553bSDave Bakker .blocking_write_util(request.as_bytes())
28*694e553bSDave Bakker .context("writing http request failed")?;
29*694e553bSDave Bakker let response = tls_input
30*694e553bSDave Bakker .blocking_read_to_end()
31*694e553bSDave Bakker .context("reading http response failed")?;
32*694e553bSDave Bakker client_connection
33*694e553bSDave Bakker .blocking_close_output(&tls_output)
34*694e553bSDave Bakker .context("closing tls connection failed")?;
35*694e553bSDave Bakker socket.shutdown(ShutdownType::Both)?;
36*694e553bSDave Bakker
37*694e553bSDave Bakker let response = String::from_utf8(response)?;
38*694e553bSDave Bakker if response.contains("HTTP/1.1 200 OK") {
39*694e553bSDave Bakker Ok(())
40*694e553bSDave Bakker } else {
41*694e553bSDave Bakker Err(anyhow!("server did not respond with 200 OK: {response}"))
42*694e553bSDave Bakker }
43*694e553bSDave Bakker }
44*694e553bSDave Bakker
45*694e553bSDave Bakker /// This test sets up a TCP connection using one domain, and then attempts to
46*694e553bSDave Bakker /// perform a TLS handshake using another unrelated domain. This should result
47*694e553bSDave Bakker /// in a handshake error.
test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()>48*694e553bSDave Bakker fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
49*694e553bSDave Bakker const BAD_DOMAIN: &'static str = "wrongdomain.localhost";
50*694e553bSDave Bakker
51*694e553bSDave Bakker let net = Network::default();
52*694e553bSDave Bakker
53*694e553bSDave Bakker let socket = TcpSocket::new(ip.family()).unwrap();
54*694e553bSDave Bakker let (tcp_input, tcp_output) = socket
55*694e553bSDave Bakker .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
56*694e553bSDave Bakker .context("tcp connect failed")?;
57*694e553bSDave Bakker
58*694e553bSDave Bakker match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() {
59*694e553bSDave Bakker Err(e) => {
60*694e553bSDave Bakker let debug_string = e.to_debug_string();
61*694e553bSDave Bakker // We're expecting an error regarding certificates in some form or
62*694e553bSDave Bakker // another. When we add more TLS backends this naive check will
63*694e553bSDave Bakker // likely need to be revisited/expanded:
64*694e553bSDave Bakker if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") {
65*694e553bSDave Bakker return Ok(());
66*694e553bSDave Bakker }
67*694e553bSDave Bakker Err(e.into())
68*694e553bSDave Bakker }
69*694e553bSDave Bakker Ok(_) => panic!("expecting server name mismatch"),
70*694e553bSDave Bakker }
71*694e553bSDave Bakker }
72*694e553bSDave Bakker
try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>)73*694e553bSDave Bakker fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) {
74*694e553bSDave Bakker // since this is testing remote endpoints to ensure system cert store works
75*694e553bSDave Bakker // the test uses a couple different endpoints to reduce the number of flakes
76*694e553bSDave Bakker const DOMAINS: &'static [&'static str] = &[
77*694e553bSDave Bakker "example.com",
78*694e553bSDave Bakker "api.github.com",
79*694e553bSDave Bakker "docs.wasmtime.dev",
80*694e553bSDave Bakker "bytecodealliance.org",
81*694e553bSDave Bakker "www.rust-lang.org",
82*694e553bSDave Bakker ];
83*694e553bSDave Bakker
84*694e553bSDave Bakker let net = Network::default();
85*694e553bSDave Bakker
86*694e553bSDave Bakker for &domain in DOMAINS {
87*694e553bSDave Bakker let result = (|| {
88*694e553bSDave Bakker let ip = net
89*694e553bSDave Bakker .permissive_blocking_resolve_addresses(domain)?
90*694e553bSDave Bakker .first()
91*694e553bSDave Bakker .map(|a| a.to_owned())
92*694e553bSDave Bakker .ok_or_else(|| anyhow!("DNS lookup failed."))?;
93*694e553bSDave Bakker test(&domain, ip)
94*694e553bSDave Bakker })();
95*694e553bSDave Bakker
96*694e553bSDave Bakker match result {
97*694e553bSDave Bakker Ok(()) => return,
98*694e553bSDave Bakker Err(e) => {
99*694e553bSDave Bakker eprintln!("test for {domain} failed: {e:#}");
100*694e553bSDave Bakker }
101*694e553bSDave Bakker }
102*694e553bSDave Bakker }
103*694e553bSDave Bakker
104*694e553bSDave Bakker panic!("all tests failed");
105*694e553bSDave Bakker }
106*694e553bSDave Bakker
main()107*694e553bSDave Bakker fn main() {
108*694e553bSDave Bakker println!("sample app");
109*694e553bSDave Bakker try_live_endpoints(test_tls_sample_application);
110*694e553bSDave Bakker println!("invalid cert");
111*694e553bSDave Bakker try_live_endpoints(test_tls_invalid_certificate);
112*694e553bSDave Bakker }
113