1 use anyhow::{Context, Result, anyhow};
2 use core::str;
3 use test_programs::wasi::sockets::network::{IpAddress, IpSocketAddress, Network};
4 use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
5 use test_programs::wasi::tls::types::ClientHandshake;
6 
/// TCP port every test connects to: the standard HTTPS port, since each test
/// performs a TLS handshake (and possibly an HTTP request) over the connection.
const PORT: u16 = 443;
8 
test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()>9 fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
10     let request = format!(
11         "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"
12     );
13 
14     let net = Network::default();
15 
16     let socket = TcpSocket::new(ip.family()).unwrap();
17     let (tcp_input, tcp_output) = socket
18         .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
19         .context("tcp connect failed")?;
20 
21     let (client_connection, tls_input, tls_output) =
22         ClientHandshake::new(domain, tcp_input, tcp_output)
23             .blocking_finish()
24             .context("tls handshake failed")?;
25 
26     tls_output
27         .blocking_write_util(request.as_bytes())
28         .context("writing http request failed")?;
29     let response = tls_input
30         .blocking_read_to_end()
31         .context("reading http response failed")?;
32     client_connection
33         .blocking_close_output(&tls_output)
34         .context("closing tls connection failed")?;
35     socket.shutdown(ShutdownType::Both)?;
36 
37     let response = String::from_utf8(response)?;
38     if response.contains("HTTP/1.1 200 OK") {
39         Ok(())
40     } else {
41         Err(anyhow!("server did not respond with 200 OK: {response}"))
42     }
43 }
44 
45 /// This test sets up a TCP connection using one domain, and then attempts to
46 /// perform a TLS handshake using another unrelated domain. This should result
47 /// in a handshake error.
test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()>48 fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
49     const BAD_DOMAIN: &'static str = "wrongdomain.localhost";
50 
51     let net = Network::default();
52 
53     let socket = TcpSocket::new(ip.family()).unwrap();
54     let (tcp_input, tcp_output) = socket
55         .blocking_connect(&net, IpSocketAddress::new(ip, PORT))
56         .context("tcp connect failed")?;
57 
58     match ClientHandshake::new(BAD_DOMAIN, tcp_input, tcp_output).blocking_finish() {
59         Err(e) => {
60             let debug_string = e.to_debug_string();
61             // We're expecting an error regarding certificates in some form or
62             // another. When we add more TLS backends this naive check will
63             // likely need to be revisited/expanded:
64             if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") {
65                 return Ok(());
66             }
67             Err(e.into())
68         }
69         Ok(_) => panic!("expecting server name mismatch"),
70     }
71 }
72 
try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>)73 fn try_live_endpoints(test: impl Fn(&str, IpAddress) -> Result<()>) {
74     // since this is testing remote endpoints to ensure system cert store works
75     // the test uses a couple different endpoints to reduce the number of flakes
76     const DOMAINS: &'static [&'static str] = &[
77         "example.com",
78         "api.github.com",
79         "docs.wasmtime.dev",
80         "bytecodealliance.org",
81         "www.rust-lang.org",
82     ];
83 
84     let net = Network::default();
85 
86     for &domain in DOMAINS {
87         let result = (|| {
88             let ip = net
89                 .permissive_blocking_resolve_addresses(domain)?
90                 .first()
91                 .map(|a| a.to_owned())
92                 .ok_or_else(|| anyhow!("DNS lookup failed."))?;
93             test(&domain, ip)
94         })();
95 
96         match result {
97             Ok(()) => return,
98             Err(e) => {
99                 eprintln!("test for {domain} failed: {e:#}");
100             }
101         }
102     }
103 
104     panic!("all tests failed");
105 }
106 
main()107 fn main() {
108     println!("sample app");
109     try_live_endpoints(test_tls_sample_application);
110     println!("invalid cert");
111     try_live_endpoints(test_tls_invalid_certificate);
112 }
113