1 use http::header::CONTENT_LENGTH;
2 use hyper::service::service_fn;
3 use hyper::{Request, Response};
4 use std::future::Future;
5 use std::net::{SocketAddr, TcpStream};
6 use std::thread::JoinHandle;
7 use tokio::net::TcpListener;
8 use tracing::{debug, trace, warn};
9 use wasmtime::{Result, error::Context as _};
10 use wasmtime_wasi_http::io::TokioIo;
11 
test( req: Request<hyper::body::Incoming>, ) -> http::Result<Response<hyper::body::Incoming>>12 async fn test(
13     req: Request<hyper::body::Incoming>,
14 ) -> http::Result<Response<hyper::body::Incoming>> {
15     debug!(?req, "preparing mocked response for request");
16     let method = req.method().to_string();
17     let uri = req.uri().to_string();
18     let resp = Response::builder()
19         .header("x-wasmtime-test-method", method)
20         .header("x-wasmtime-test-uri", uri);
21     let resp = if let Some(content_length) = req.headers().get(CONTENT_LENGTH) {
22         resp.header(CONTENT_LENGTH, content_length)
23     } else {
24         resp
25     };
26     let body = req.into_body();
27     resp.body(body)
28 }
29 
/// A loopback test server that handles a fixed number of connections on a
/// dedicated background thread.
///
/// Construct one with `Server::http1` or `Server::http2`. Dropping the value
/// shuts the background thread down.
#[derive(Debug)]
pub struct Server {
    /// Number of accept iterations the worker thread performs before exiting.
    conns: usize,
    /// Local address the listener is bound to.
    addr: SocketAddr,
    /// Worker thread running the accept loop. It becomes `None` only after
    /// being taken for joining.
    worker: Option<JoinHandle<()>>,
}
35 
36 impl Server {
new<F>( conns: usize, run: impl Fn(TokioIo<tokio::net::TcpStream>) -> F + Send + 'static, ) -> Result<Self> where F: Future<Output = Result<()>>,37     fn new<F>(
38         conns: usize,
39         run: impl Fn(TokioIo<tokio::net::TcpStream>) -> F + Send + 'static,
40     ) -> Result<Self>
41     where
42         F: Future<Output = Result<()>>,
43     {
44         let thread = std::thread::spawn(|| -> Result<_> {
45             let rt = tokio::runtime::Builder::new_current_thread()
46                 .enable_all()
47                 .build()
48                 .context("failed to start tokio runtime")?;
49             let listener = rt.block_on(async move {
50                 let addr = SocketAddr::from(([127, 0, 0, 1], 0));
51                 TcpListener::bind(addr).await.context("failed to bind")
52             })?;
53             Ok((rt, listener))
54         });
55         let (rt, listener) = thread.join().unwrap()?;
56         let addr = listener.local_addr().context("failed to get local addr")?;
57         let worker = std::thread::spawn(move || {
58             debug!("dedicated thread to start listening");
59             rt.block_on(async move {
60                 for i in 0..conns {
61                     debug!(i, "preparing to accept connection");
62                     match listener.accept().await {
63                         Ok((stream, ..)) => {
64                             debug!(i, "accepted connection");
65                             if let Err(err) = run(TokioIo::new(stream)).await {
66                                 warn!(i, ?err, "failed to serve connection");
67                             }
68                         }
69                         Err(err) => {
70                             warn!(i, ?err, "failed to accept connection");
71                         }
72                     };
73                 }
74             })
75         });
76         Ok(Self {
77             conns,
78             worker: Some(worker),
79             addr,
80         })
81     }
82 
http1(conns: usize) -> Result<Self>83     pub fn http1(conns: usize) -> Result<Self> {
84         debug!("initializing http1 server");
85         Self::new(conns, |io| async move {
86             let mut builder = hyper::server::conn::http1::Builder::new();
87             let http = builder.keep_alive(false).pipeline_flush(true);
88 
89             debug!("preparing to bind connection to service");
90             let conn = http.serve_connection(io, service_fn(test)).await;
91             trace!("connection result {:?}", conn);
92             conn?;
93             Ok(())
94         })
95     }
96 
http2(conns: usize) -> Result<Self>97     pub fn http2(conns: usize) -> Result<Self> {
98         debug!("initializing http2 server");
99         Self::new(conns, |io| async move {
100             let mut builder = hyper::server::conn::http2::Builder::new(TokioExecutor);
101             let http = builder.max_concurrent_streams(20);
102 
103             debug!("preparing to bind connection to service");
104             let conn = http.serve_connection(io, service_fn(test)).await;
105             trace!("connection result {:?}", conn);
106             if let Err(e) = &conn {
107                 let message = e.to_string();
108                 if message.contains("connection closed before reading preface")
109                     || message.contains("unspecific protocol error detected")
110                 {
111                     return Ok(());
112                 }
113             }
114             conn?;
115             Ok(())
116         })
117     }
118 
addr(&self) -> String119     pub fn addr(&self) -> String {
120         format!("localhost:{}", self.addr.port())
121     }
122 }
123 
124 impl Drop for Server {
drop(&mut self)125     fn drop(&mut self) {
126         debug!("shutting down http1 server");
127         for _ in 0..self.conns {
128             // Force a connection to happen in case one hasn't happened already.
129             let _ = TcpStream::connect(&self.addr);
130         }
131         self.worker.take().unwrap().join().unwrap();
132     }
133 }
134 
/// Zero-sized hyper executor that spawns hyper's background futures onto the
/// ambient tokio runtime. The HTTP/2 server builder uses it.
#[derive(Clone)]
struct TokioExecutor;
138 
139 impl<F> hyper::rt::Executor<F> for TokioExecutor
140 where
141     F: Future + Send + 'static,
142     F::Output: Send + 'static,
143 {
execute(&self, fut: F)144     fn execute(&self, fut: F) {
145         tokio::task::spawn(fut);
146     }
147 }
148