1 use anyhow::{Context as _, Result, anyhow};
2 use core::future::Future;
3 use test_programs::p3::wasi::sockets::ip_name_lookup::resolve_addresses;
4 use test_programs::p3::wasi::sockets::types::{IpAddress, IpSocketAddress, TcpSocket};
5 use test_programs::p3::wasi::tls::client::Connector;
6 use test_programs::p3::wit_stream;
7 
/// Component type whose `Guest` implementation (further down in this file)
/// provides the exported `wasi:cli/run` entry point.
struct Component;

// Register `Component` as the implementation of this component's exports
// through the `test_programs` p3 bindings.
test_programs::p3::export!(Component);

/// TCP port every test connects to (the standard HTTPS port).
const PORT: u16 = 443;
13 
test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()>14 async fn test_tls_sample_application(domain: &str, ip: IpAddress) -> Result<()> {
15     let request = format!(
16         "GET / HTTP/1.1\r\nHost: {domain}\r\nUser-Agent: wasmtime-wasi-rust\r\nConnection: close\r\n\r\n"
17     );
18 
19     let sock = TcpSocket::create(ip.family()).unwrap();
20     sock.connect(IpSocketAddress::new(ip, PORT))
21         .await
22         .context("tcp connect failed")?;
23 
24     let conn = Connector::new();
25 
26     let (sock_rx, sock_rx_fut) = sock.receive();
27     let (tls_rx, tls_rx_fut) = conn.receive(sock_rx);
28 
29     let (mut data_tx, data_rx) = wit_stream::new();
30     let (tls_tx, tls_tx_err_fut) = conn.send(data_rx);
31     let sock_tx_fut = sock.send(tls_tx);
32 
33     Connector::connect(conn, domain.into())
34         .await
35         .context("tls handshake failed")?;
36     let buf = data_tx.write_all(request.into()).await;
37     assert!(buf.is_empty());
38 
39     let response = tls_rx.collect().await;
40     let response = String::from_utf8(response)?;
41     if !response.contains("HTTP/1.1 200 OK") {
42         return Err(anyhow!("server did not respond with 200 OK: {response}"));
43     }
44     drop(data_tx);
45     sock_rx_fut.await.context("tcp recv")?;
46     sock_tx_fut.await.context("tcp send")?;
47     tls_rx_fut.await.context("tls recv")?;
48     tls_tx_err_fut.await.context("tls send")?;
49 
50     Ok(())
51 }
52 
53 /// This test sets up a TCP connection using one domain, and then attempts to
54 /// perform a TLS handshake using another unrelated domain. This should result
55 /// in a handshake error.
test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()>56 async fn test_tls_invalid_certificate(_domain: &str, ip: IpAddress) -> Result<()> {
57     const BAD_DOMAIN: &str = "wrongdomain.localhost";
58 
59     let sock = TcpSocket::create(ip.family()).unwrap();
60     sock.connect(IpSocketAddress::new(ip, PORT))
61         .await
62         .context("tcp connect failed")?;
63 
64     let (_, data_rx) = wit_stream::new();
65     let conn = Connector::new();
66 
67     conn.receive(sock.receive().0);
68     sock.send(conn.send(data_rx).0);
69 
70     match Connector::connect(conn, BAD_DOMAIN.into()).await {
71         Err(e) => {
72             let debug_string = e.to_debug_string();
73             // We're expecting an error regarding certificates in some form or
74             // another. When we add more TLS backends this naive check will
75             // likely need to be revisited/expanded:
76             if debug_string.contains("certificate") || debug_string.contains("HandshakeFailure") {
77                 return Ok(());
78             }
79             Err(anyhow!(debug_string))
80         }
81         Ok(_) => panic!("expecting server name mismatch"),
82     }
83 }
84 
try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut) where Fut: Future<Output = Result<()>> + 'a,85 async fn try_live_endpoints<'a, Fut>(test: impl Fn(&'a str, IpAddress) -> Fut)
86 where
87     Fut: Future<Output = Result<()>> + 'a,
88 {
89     // since this is testing remote endpoints to ensure system cert store works
90     // the test uses a couple different endpoints to reduce the number of flakes
91     const DOMAINS: &[&str] = &[
92         "example.com",
93         "api.github.com",
94         "docs.wasmtime.dev",
95         "bytecodealliance.org",
96         "www.rust-lang.org",
97     ];
98 
99     for &domain in DOMAINS {
100         let result = (|| async {
101             let ip = resolve_addresses(domain.into())
102                 .await?
103                 .first()
104                 .map(|a| a.to_owned())
105                 .ok_or_else(|| anyhow!("DNS lookup failed."))?;
106             test(domain, ip).await
107         })();
108 
109         match result.await {
110             Ok(()) => return,
111             Err(e) => {
112                 eprintln!("test for {domain} failed: {e:#}");
113             }
114         }
115     }
116 
117     panic!("all tests failed");
118 }
119 
impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
    /// Component entry point: runs the successful-request test and then the
    /// invalid-certificate test, each via `try_live_endpoints`.
    ///
    /// Always returns `Ok(())`; test failures surface as panics (raised by
    /// `try_live_endpoints` when every endpoint fails, or by the tests
    /// themselves).
    async fn run() -> Result<(), ()> {
        // HTTPS request that must succeed against at least one endpoint.
        println!("sample app");
        try_live_endpoints(test_tls_sample_application).await;
        // Handshake under a mismatched server name, which must be rejected.
        println!("invalid cert");
        try_live_endpoints(test_tls_invalid_certificate).await;
        Ok(())
    }
}
129 
// Intentionally empty: the program's work is done in the exported `run`
// implementation on `Component`, not here.
// NOTE(review): presumably kept only so the crate builds as a binary target —
// confirm.
fn main() {}
131