1 use anyhow::{Result, anyhow, bail};
2 use futures::{Future, SinkExt, StreamExt, TryStreamExt, future, stream};
3 use test_programs::wasi::http::types::{
4     Fields, IncomingRequest, IncomingResponse, Method, OutgoingBody, OutgoingRequest,
5     OutgoingResponse, ResponseOutparam, Scheme,
6 };
7 use url::Url;
8 
/// Maximum number of `/hash-all` fetches polled concurrently (the limit passed to
/// `buffer_unordered`).
const MAX_CONCURRENCY: usize = 16;
10 
/// This component's implementation of `wasi:http/incoming-handler`.
struct Handler;

// Export `Handler` through the `test_programs::proxy` bindings.
test_programs::proxy::export!(Handler);
14 
15 impl test_programs::proxy::exports::wasi::http::incoming_handler::Guest for Handler {
16     fn handle(request: IncomingRequest, response_out: ResponseOutparam) {
17         executor::run(async move {
18             handle_request(request, response_out).await;
19         })
20     }
21 }
22 
23 async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam) {
24     let headers = request.headers().entries();
25 
26     assert!(request.authority().is_some());
27 
28     match (request.method(), request.path_with_query().as_deref()) {
29         (Method::Get, Some("/hash-all")) => {
30             // Send outgoing GET requests to the specified URLs and stream the hashes of the response bodies as
31             // they arrive.
32 
33             let urls = headers.iter().filter_map(|(k, v)| {
34                 (k == "url")
35                     .then_some(v)
36                     .and_then(|v| std::str::from_utf8(v).ok())
37                     .and_then(|v| Url::parse(v).ok())
38             });
39 
40             let results = urls.map(|url| async move {
41                 let result = hash(&url).await;
42                 (url, result)
43             });
44 
45             let mut results = stream::iter(results).buffer_unordered(MAX_CONCURRENCY);
46 
47             let response = OutgoingResponse::new(
48                 Fields::from_list(&[("content-type".to_string(), b"text/plain".to_vec())]).unwrap(),
49             );
50 
51             let mut body =
52                 executor::outgoing_body(response.body().expect("response should be writable"));
53 
54             ResponseOutparam::set(response_out, Ok(response));
55 
56             while let Some((url, result)) = results.next().await {
57                 let payload = match result {
58                     Ok(hash) => format!("{url}: {hash}\n"),
59                     Err(e) => format!("{url}: {e:?}\n"),
60                 }
61                 .into_bytes();
62 
63                 if let Err(e) = body.send(payload).await {
64                     eprintln!("Error sending payload: {e}");
65                 }
66             }
67         }
68 
69         (Method::Post, Some("/echo")) => {
70             // Echo the request body without buffering it.
71 
72             let response = OutgoingResponse::new(
73                 Fields::from_list(
74                     &headers
75                         .into_iter()
76                         .filter_map(|(k, v)| (k == "content-type").then_some((k, v)))
77                         .collect::<Vec<_>>(),
78                 )
79                 .unwrap(),
80             );
81 
82             let mut body =
83                 executor::outgoing_body(response.body().expect("response should be writable"));
84 
85             ResponseOutparam::set(response_out, Ok(response));
86 
87             let mut stream =
88                 executor::incoming_body(request.consume().expect("request should be readable"));
89 
90             while let Some(chunk) = stream.next().await {
91                 match chunk {
92                     Ok(chunk) => {
93                         if let Err(e) = body.send(chunk).await {
94                             eprintln!("Error sending body: {e}");
95                             break;
96                         }
97                     }
98                     Err(e) => {
99                         eprintln!("Error receiving body: {e}");
100                         break;
101                     }
102                 }
103             }
104         }
105 
106         (Method::Post, Some("/double-echo")) => {
107             // Pipe the request body to an outgoing request and stream the response back to the client.
108 
109             if let Some(url) = headers.iter().find_map(|(k, v)| {
110                 (k == "url")
111                     .then_some(v)
112                     .and_then(|v| std::str::from_utf8(v).ok())
113                     .and_then(|v| Url::parse(v).ok())
114             }) {
115                 match double_echo(request, &url).await {
116                     Ok((request_copy, response)) => {
117                         let mut stream = executor::incoming_body(
118                             response.consume().expect("response should be consumable"),
119                         );
120 
121                         let response = OutgoingResponse::new(
122                             Fields::from_list(
123                                 &headers
124                                     .into_iter()
125                                     .filter_map(|(k, v)| (k == "content-type").then_some((k, v)))
126                                     .collect::<Vec<_>>(),
127                             )
128                             .unwrap(),
129                         );
130 
131                         let mut body = executor::outgoing_body(
132                             response.body().expect("response should be writable"),
133                         );
134 
135                         ResponseOutparam::set(response_out, Ok(response));
136 
137                         let response_copy = async move {
138                             while let Some(chunk) = stream.next().await {
139                                 body.send(chunk?).await?;
140                             }
141                             Ok::<_, anyhow::Error>(())
142                         };
143 
144                         let (request_copy, response_copy) =
145                             future::join(request_copy, response_copy).await;
146                         if let Err(e) = request_copy.and(response_copy) {
147                             eprintln!("error piping to and from {url}: {e}");
148                         }
149                     }
150 
151                     Err(e) => {
152                         eprintln!("Error sending outgoing request to {url}: {e}");
153                         server_error(response_out);
154                     }
155                 }
156             } else {
157                 bad_request(response_out);
158             }
159         }
160 
161         _ => method_not_allowed(response_out),
162     }
163 }
164 
165 async fn double_echo(
166     incoming_request: IncomingRequest,
167     url: &Url,
168 ) -> Result<(impl Future<Output = Result<()>> + use<>, IncomingResponse)> {
169     let outgoing_request = OutgoingRequest::new(Fields::new());
170 
171     outgoing_request
172         .set_method(&Method::Post)
173         .map_err(|()| anyhow!("failed to set method"))?;
174 
175     outgoing_request
176         .set_path_with_query(Some(url.path()))
177         .map_err(|()| anyhow!("failed to set path_with_query"))?;
178 
179     outgoing_request
180         .set_scheme(Some(&match url.scheme() {
181             "http" => Scheme::Http,
182             "https" => Scheme::Https,
183             scheme => Scheme::Other(scheme.into()),
184         }))
185         .map_err(|()| anyhow!("failed to set scheme"))?;
186 
187     outgoing_request
188         .set_authority(Some(&format!(
189             "{}{}",
190             url.host_str().unwrap_or(""),
191             if let Some(port) = url.port() {
192                 format!(":{port}")
193             } else {
194                 String::new()
195             }
196         )))
197         .map_err(|()| anyhow!("failed to set authority"))?;
198 
199     let mut body = executor::outgoing_body(
200         outgoing_request
201             .body()
202             .expect("request body should be writable"),
203     );
204 
205     let response = executor::outgoing_request_send(outgoing_request);
206 
207     let mut stream = executor::incoming_body(
208         incoming_request
209             .consume()
210             .expect("request should be consumable"),
211     );
212 
213     let copy = async move {
214         while let Some(chunk) = stream.next().await {
215             body.send(chunk?).await?;
216         }
217         Ok::<_, anyhow::Error>(())
218     };
219 
220     let response = response.await?;
221 
222     let status = response.status();
223 
224     if !(200..300).contains(&status) {
225         bail!("unexpected status: {status}");
226     }
227 
228     Ok((copy, response))
229 }
230 
231 fn server_error(response_out: ResponseOutparam) {
232     respond(500, response_out)
233 }
234 
235 fn bad_request(response_out: ResponseOutparam) {
236     respond(400, response_out)
237 }
238 
239 fn method_not_allowed(response_out: ResponseOutparam) {
240     respond(405, response_out)
241 }
242 
243 fn respond(status: u16, response_out: ResponseOutparam) {
244     let response = OutgoingResponse::new(Fields::new());
245     response
246         .set_status_code(status)
247         .expect("setting status code");
248 
249     let body = response.body().expect("response should be writable");
250 
251     ResponseOutparam::set(response_out, Ok(response));
252 
253     OutgoingBody::finish(body, None).expect("outgoing-body.finish");
254 }
255 
256 async fn hash(url: &Url) -> Result<String> {
257     let request = OutgoingRequest::new(Fields::new());
258 
259     request
260         .set_path_with_query(Some(url.path()))
261         .map_err(|()| anyhow!("failed to set path_with_query"))?;
262     request
263         .set_scheme(Some(&match url.scheme() {
264             "http" => Scheme::Http,
265             "https" => Scheme::Https,
266             scheme => Scheme::Other(scheme.into()),
267         }))
268         .map_err(|()| anyhow!("failed to set scheme"))?;
269     request
270         .set_authority(Some(&format!(
271             "{}{}",
272             url.host_str().unwrap_or(""),
273             if let Some(port) = url.port() {
274                 format!(":{port}")
275             } else {
276                 String::new()
277             }
278         )))
279         .map_err(|()| anyhow!("failed to set authority"))?;
280 
281     let response = executor::outgoing_request_send(request).await?;
282 
283     let status = response.status();
284 
285     if !(200..300).contains(&status) {
286         bail!("unexpected status: {status}");
287     }
288 
289     let mut body =
290         executor::incoming_body(response.consume().expect("response should be readable"));
291 
292     use sha2::Digest;
293     let mut hasher = sha2::Sha256::new();
294     while let Some(chunk) = body.try_next().await? {
295         hasher.update(&chunk);
296     }
297 
298     use base64::Engine;
299     Ok(base64::engine::general_purpose::STANDARD_NO_PAD.encode(hasher.finalize()))
300 }
301 
// A proxy component is driven entirely through its exported `incoming-handler`, so it has no
// real use for `main`. The current test framework builds this file as a `bin` target, though,
// and a binary target must define `main`, so an empty one is provided.
fn main() {}
305 
/// A minimal async executor driven by `wasi:io/poll`, plus adapters that expose WASI HTTP
/// bodies as `futures` `Sink`s and `Stream`s.
///
/// Leaf futures in this module register interest by pushing a `(Pollable, Waker)` pair onto
/// `WAKERS` before returning `Poll::Pending`. `run` then blocks in `io::poll::poll` on all
/// registered pollables, wakes the entries that became ready, and re-polls its future.
mod executor {
    use anyhow::{Error, Result, anyhow};
    use futures::{Sink, Stream, future, sink, stream};
    use std::{
        cell::RefCell,
        future::Future,
        mem,
        rc::Rc,
        sync::{Arc, Mutex},
        task::{Context, Poll, Wake, Waker},
    };
    use test_programs::wasi::{
        http::{
            outgoing_handler,
            types::{
                self, FutureTrailers, IncomingBody, IncomingResponse, InputStream, OutgoingBody,
                OutgoingRequest, OutputStream,
            },
        },
        io::{self, streams::StreamError},
    };

    /// Number of bytes requested from an input stream per `read` call.
    const READ_SIZE: u64 = 16 * 1024;

    /// Pollables registered by pending leaf futures, each paired with the waker to invoke
    /// once that pollable is ready. (A `static` must be `Sync`, hence the `Mutex`.)
    static WAKERS: Mutex<Vec<(io::poll::Pollable, Waker)>> = Mutex::new(Vec::new());

    /// Drives `future` to completion on the calling thread and returns its output.
    ///
    /// The top-level waker does nothing. Progress comes from re-polling `future` after
    /// `io::poll::poll` reports that at least one registered pollable is ready; the wakers
    /// stored next to ready pollables are invoked first. Entries that are not ready are kept
    /// for the next round. Because a pending future may register again on every poll, the
    /// same event can appear in `WAKERS` more than once until it fires.
    ///
    /// # Panics
    ///
    /// Panics if `future` returns `Pending` without registering any pollable, since nothing
    /// could ever make it progress.
    pub fn run<T>(future: impl Future<Output = T>) -> T {
        futures::pin_mut!(future);

        // No-op waker: readiness is discovered by blocking on `io::poll::poll` instead.
        struct DummyWaker;

        impl Wake for DummyWaker {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Arc::new(DummyWaker).into();

        loop {
            match future.as_mut().poll(&mut Context::from_waker(&waker)) {
                Poll::Pending => {
                    let mut new_wakers = Vec::new();

                    // Take every registration; non-ready ones are put back below.
                    let wakers = mem::take::<Vec<_>>(&mut WAKERS.lock().unwrap());

                    assert!(!wakers.is_empty());

                    let pollables = wakers
                        .iter()
                        .map(|(pollable, _)| pollable)
                        .collect::<Vec<_>>();

                    let mut ready = vec![false; wakers.len()];

                    // Block until at least one pollable is ready; `poll` returns the indices
                    // of the ready ones.
                    for index in io::poll::poll(&pollables) {
                        ready[usize::try_from(index).unwrap()] = true;
                    }

                    for (ready, (pollable, waker)) in ready.into_iter().zip(wakers) {
                        if ready {
                            waker.wake()
                        } else {
                            new_wakers.push((pollable, waker));
                        }
                    }

                    *WAKERS.lock().unwrap() = new_wakers;
                }
                Poll::Ready(result) => break result,
            }
        }
    }

    /// Wraps `body` in a `Sink` that accepts byte chunks.
    ///
    /// Each chunk is written in pieces no larger than `check_write` allows and then flushed.
    /// A send resolves only after `check_write` reports capacity again following the flush.
    /// A stream error resolves the send with an "I/O error".
    ///
    /// The body is finished (without trailers) when the last reference to it is dropped —
    /// either the sink itself or an in-flight send future.
    ///
    /// # Panics
    ///
    /// Panics if the body's output stream cannot be obtained, if `flush` fails, or if
    /// finishing the body fails on drop.
    pub fn outgoing_body(body: OutgoingBody) -> impl Sink<Vec<u8>, Error = Error> {
        // `Option` so `Drop` can take ownership of the stream and body.
        struct Outgoing(Option<(OutputStream, OutgoingBody)>);

        impl Drop for Outgoing {
            fn drop(&mut self) {
                if let Some((stream, body)) = self.0.take() {
                    // Release the child stream before finishing the body (presumably required
                    // by the WASI body API — TODO confirm).
                    drop(stream);
                    OutgoingBody::finish(body, None).expect("outgoing-body.finish");
                }
            }
        }

        let stream = body.write().expect("response body should be writable");
        // Shared between the sink closure and every per-chunk send future.
        let pair = Rc::new(RefCell::new(Outgoing(Some((stream, body)))));

        sink::unfold((), {
            move |(), chunk: Vec<u8>| {
                future::poll_fn({
                    // Per-chunk progress: bytes of `chunk` written so far, and whether the
                    // final flush has been issued.
                    let mut offset = 0;
                    let mut flushing = false;
                    let pair = pair.clone();

                    move |context| {
                        let pair = pair.borrow();
                        let (stream, _) = &pair.0.as_ref().unwrap();

                        loop {
                            match stream.check_write() {
                                // No capacity (or a flush still in progress): wait until the
                                // stream is writable again.
                                Ok(0) => {
                                    WAKERS
                                        .lock()
                                        .unwrap()
                                        .push((stream.subscribe(), context.waker().clone()));

                                    break Poll::Pending;
                                }
                                Ok(count) => {
                                    if offset == chunk.len() {
                                        // Everything written. Issue one flush, then complete
                                        // once capacity is reported again after it.
                                        if flushing {
                                            break Poll::Ready(Ok(()));
                                        } else {
                                            stream.flush().expect("stream should be flushable");
                                            flushing = true;
                                        }
                                    } else {
                                        // Write no more than the permitted `count` bytes.
                                        let count = usize::try_from(count)
                                            .unwrap()
                                            .min(chunk.len() - offset);

                                        match stream.write(&chunk[offset..][..count]) {
                                            Ok(()) => {
                                                offset += count;
                                            }
                                            Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))),
                                        }
                                    }
                                }
                                Err(_) => break Poll::Ready(Err(anyhow!("I/O error"))),
                            }
                        }
                    }
                })
            }
        })
    }

    /// Sends `request` through `outgoing_handler::handle` and returns a future that resolves
    /// to the response.
    ///
    /// `handle` is called right away, when this function is called, not when the returned
    /// future is first polled. If `handle` itself fails, the future resolves to a clone of
    /// that error.
    ///
    /// # Panics
    ///
    /// The `unwrap` on the `get()` result panics if that result is an `Err(())` — presumably
    /// when the response was already taken, mirroring the trailers case in `incoming_body`.
    pub fn outgoing_request_send(
        request: OutgoingRequest,
    ) -> impl Future<Output = Result<IncomingResponse, types::ErrorCode>> {
        future::poll_fn({
            let response = outgoing_handler::handle(request, None);

            move |context| match &response {
                Ok(response) => {
                    if let Some(response) = response.get() {
                        Poll::Ready(response.unwrap())
                    } else {
                        // Not ready yet: wait on the future-incoming-response's pollable.
                        WAKERS
                            .lock()
                            .unwrap()
                            .push((response.subscribe(), context.waker().clone()));
                        Poll::Pending
                    }
                }
                Err(error) => Poll::Ready(Err(error.clone())),
            }
        })
    }

    /// Adapts `body` into a `Stream` of byte chunks, read in requests of `READ_SIZE` bytes.
    ///
    /// When the input stream reports `Closed`, the body is finished and its trailers are
    /// awaited before the stream ends. The trailers themselves are ignored. A failed read or
    /// an error retrieving the trailers is yielded as an `Err` item.
    ///
    /// Dropping the stream while still reading finishes the body and discards the returned
    /// trailers future.
    ///
    /// # Panics
    ///
    /// Panics if the body's input stream cannot be obtained.
    pub fn incoming_body(body: IncomingBody) -> impl Stream<Item = Result<Vec<u8>>> {
        /// Read state: reading the body, waiting for trailers, or done.
        enum Inner {
            Stream {
                stream: InputStream,
                body: IncomingBody,
            },
            Trailers(FutureTrailers),
            Closed,
        }

        struct Incoming(Inner);

        impl Drop for Incoming {
            fn drop(&mut self) {
                match mem::replace(&mut self.0, Inner::Closed) {
                    Inner::Stream { stream, body } => {
                        // Release the child stream before finishing the body.
                        drop(stream);
                        IncomingBody::finish(body);
                    }
                    Inner::Trailers(_) | Inner::Closed => {}
                }
            }
        }

        stream::poll_fn({
            let stream = body.stream().expect("response body should be readable");
            let mut incoming = Incoming(Inner::Stream { stream, body });

            move |context| {
                loop {
                    match &incoming.0 {
                        Inner::Stream { stream, .. } => match stream.read(READ_SIZE) {
                            Ok(buffer) => {
                                // An empty read means no data yet: wait until readable.
                                return if buffer.is_empty() {
                                    WAKERS
                                        .lock()
                                        .unwrap()
                                        .push((stream.subscribe(), context.waker().clone()));
                                    Poll::Pending
                                } else {
                                    Poll::Ready(Some(Ok(buffer)))
                                };
                            }
                            Err(StreamError::Closed) => {
                                // End of body: finish it and move on to awaiting trailers.
                                let Inner::Stream { stream, body } =
                                    mem::replace(&mut incoming.0, Inner::Closed)
                                else {
                                    unreachable!();
                                };
                                drop(stream);
                                incoming.0 = Inner::Trailers(IncomingBody::finish(body));
                            }
                            Err(StreamError::LastOperationFailed(error)) => {
                                // State stays `Stream`, so polling again attempts another read.
                                return Poll::Ready(Some(Err(anyhow!(
                                    "{}",
                                    error.to_debug_string()
                                ))));
                            }
                        },

                        Inner::Trailers(trailers) => {
                            match trailers.get() {
                                Some(Ok(trailers)) => {
                                    incoming.0 = Inner::Closed;
                                    match trailers {
                                        Ok(Some(_)) => {
                                            // Currently, we just ignore any trailers.  TODO: Add a test that
                                            // expects trailers and verify they match the expected contents.
                                        }
                                        Ok(None) => {
                                            // No trailers; nothing else to do.
                                        }
                                        Err(error) => {
                                            // Error reading the trailers: pass it on to the application.
                                            return Poll::Ready(Some(Err(anyhow!("{error:?}"))));
                                        }
                                    }
                                }
                                Some(Err(_)) => {
                                    // Should only happen if we try to retrieve the trailers twice, i.e. a bug in
                                    // this code.
                                    unreachable!();
                                }
                                None => {
                                    // Trailers not available yet: wait on their pollable.
                                    WAKERS
                                        .lock()
                                        .unwrap()
                                        .push((trailers.subscribe(), context.waker().clone()));
                                    return Poll::Pending;
                                }
                            }
                        }

                        Inner::Closed => {
                            return Poll::Ready(None);
                        }
                    }
                }
            }
        })
    }
}
569