1 use anyhow::Result;
2 use futures::{Sink, Stream, channel::oneshot};
3 use std::{
4     marker::PhantomData,
5     pin::Pin,
6     task::{Context, Poll},
7     thread,
8 };
9 use wasmtime::{
10     StoreContextMut,
11     component::{
12         Destination, FutureConsumer, FutureProducer, Lift, Lower, Source, StreamConsumer,
13         StreamProducer, StreamResult,
14     },
15 };
16 
17 pub async fn sleep(duration: std::time::Duration) {
18     if cfg!(miri) {
19         // TODO: We should be able to use `tokio::time::sleep` here, but as of
20         // this writing the miri-compatible version of `wasmtime-fiber` uses
21         // threads behind the scenes, which means thread-local storage is not
22         // preserved when we switch fibers, and that confuses Tokio.  If we ever
23         // fix that we can stop using our own, special version of `sleep` and
24         // switch back to the Tokio version.
25 
26         let (tx, rx) = oneshot::channel();
27         let handle = thread::spawn(move || {
28             thread::sleep(duration);
29             _ = tx.send(());
30         });
31         _ = rx.await;
32         _ = handle.join();
33     } else {
34         tokio::time::sleep(duration).await;
35     }
36 }
37 
/// Thin wrapper around a value of type `S`.
///
/// When `S` is a `Stream`, the `StreamProducer` implementation for this type
/// forwards the stream's items one at a time.
pub struct PipeProducer<S>(S);

impl<S> PipeProducer<S> {
    /// Wraps `rx` in a new `PipeProducer`.
    pub fn new(rx: S) -> Self {
        PipeProducer(rx)
    }
}
45 
46 impl<D, T: Send + Sync + Lower + 'static, S: Stream<Item = T> + Send + 'static> StreamProducer<D>
47     for PipeProducer<S>
48 {
49     type Item = T;
50     type Buffer = Option<T>;
51 
52     fn poll_produce<'a>(
53         self: Pin<&mut Self>,
54         cx: &mut Context<'_>,
55         _: StoreContextMut<D>,
56         mut destination: Destination<'a, Self::Item, Self::Buffer>,
57         finish: bool,
58     ) -> Poll<Result<StreamResult>> {
59         // SAFETY: This is a standard pin-projection, and we never move
60         // out of `self`.
61         let stream = unsafe { self.map_unchecked_mut(|v| &mut v.0) };
62 
63         match stream.poll_next(cx) {
64             Poll::Pending => {
65                 if finish {
66                     Poll::Ready(Ok(StreamResult::Cancelled))
67                 } else {
68                     Poll::Pending
69                 }
70             }
71             Poll::Ready(Some(item)) => {
72                 destination.set_buffer(Some(item));
73                 Poll::Ready(Ok(StreamResult::Completed))
74             }
75             Poll::Ready(None) => Poll::Ready(Ok(StreamResult::Dropped)),
76         }
77     }
78 }
79 
/// Thin wrapper around a value of type `S`, tagged with an item type `T`.
///
/// `T` is recorded only through `PhantomData<fn() -> T>`; no `T` is stored.
/// When `S` is a `Sink<T>`, the `StreamConsumer` implementation for this type
/// pushes received items into it.
pub struct PipeConsumer<T, S>(S, PhantomData<fn() -> T>);

impl<T, S> PipeConsumer<T, S> {
    /// Wraps `tx` in a new `PipeConsumer`.
    pub fn new(tx: S) -> Self {
        PipeConsumer(tx, PhantomData)
    }
}
87 
88 impl<D, T: Lift + 'static, S: Sink<T, Error: std::error::Error + Send + Sync> + Send + 'static>
89     StreamConsumer<D> for PipeConsumer<T, S>
90 {
91     type Item = T;
92 
93     fn poll_consume(
94         self: Pin<&mut Self>,
95         cx: &mut Context<'_>,
96         store: StoreContextMut<D>,
97         mut source: Source<Self::Item>,
98         finish: bool,
99     ) -> Poll<Result<StreamResult>> {
100         // SAFETY: This is a standard pin-projection, and we never move
101         // out of `self`.
102         let mut sink = unsafe { self.map_unchecked_mut(|v| &mut v.0) };
103 
104         let on_pending = || {
105             if finish {
106                 Poll::Ready(Ok(StreamResult::Cancelled))
107             } else {
108                 Poll::Pending
109             }
110         };
111 
112         match sink.as_mut().poll_flush(cx) {
113             Poll::Pending => on_pending(),
114             Poll::Ready(result) => {
115                 result?;
116                 match sink.as_mut().poll_ready(cx) {
117                     Poll::Pending => on_pending(),
118                     Poll::Ready(result) => {
119                         result?;
120                         let item = &mut None;
121                         source.read(store, item)?;
122                         sink.start_send(item.take().unwrap())?;
123                         Poll::Ready(Ok(StreamResult::Completed))
124                     }
125                 }
126             }
127         }
128     }
129 }
130 
131 pub struct OneshotProducer<T>(oneshot::Receiver<T>);
132 
133 impl<T> OneshotProducer<T> {
134     pub fn new(rx: oneshot::Receiver<T>) -> Self {
135         Self(rx)
136     }
137 }
138 
139 impl<D, T: Send + 'static> FutureProducer<D> for OneshotProducer<T> {
140     type Item = T;
141 
142     fn poll_produce(
143         self: Pin<&mut Self>,
144         cx: &mut Context<'_>,
145         _: StoreContextMut<D>,
146         finish: bool,
147     ) -> Poll<Result<Option<T>>> {
148         match Pin::new(&mut self.get_mut().0).poll(cx) {
149             Poll::Pending if finish => Poll::Ready(Ok(None)),
150             Poll::Pending => Poll::Pending,
151             Poll::Ready(result) => Poll::Ready(Ok(Some(result?))),
152         }
153     }
154 }
155 
156 pub struct OneshotConsumer<T>(Option<oneshot::Sender<T>>);
157 
158 impl<T> OneshotConsumer<T> {
159     pub fn new(tx: oneshot::Sender<T>) -> Self {
160         Self(Some(tx))
161     }
162 }
163 
164 impl<D, T: Lift + Send + 'static> FutureConsumer<D> for OneshotConsumer<T> {
165     type Item = T;
166 
167     fn poll_consume(
168         self: Pin<&mut Self>,
169         _: &mut Context<'_>,
170         store: StoreContextMut<D>,
171         mut source: Source<'_, T>,
172         _: bool,
173     ) -> Poll<Result<()>> {
174         let value = &mut None;
175         source.read(store, value)?;
176         _ = self.get_mut().0.take().unwrap().send(value.take().unwrap());
177         Poll::Ready(Ok(()))
178     }
179 }
180