xref: /tonic/interop/src/server.rs (revision ff71e893)
1 use crate::pb::{self, *};
2 use async_stream::try_stream;
3 use http::header::{HeaderMap, HeaderName, HeaderValue};
4 use http_body::Body;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use std::time::Duration;
9 use tokio_stream::StreamExt;
10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status};
11 use tower::Service;
12 
13 pub use pb::test_service_server::TestServiceServer;
14 pub use pb::unimplemented_service_server::UnimplementedServiceServer;
15 
16 #[derive(Default, Clone)]
17 pub struct TestService;
18 
19 type Result<T> = std::result::Result<Response<T>, Status>;
20 type Streaming<T> = Request<tonic::Streaming<T>>;
21 type Stream<T> =
22     Pin<Box<dyn tokio_stream::Stream<Item = std::result::Result<T, Status>> + Send + 'static>>;
23 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'static>>;
24 
25 #[tonic::async_trait]
26 impl pb::test_service_server::TestService for TestService {
27     async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
28         Ok(Response::new(Empty {}))
29     }
30 
31     async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
32         let req = request.into_inner();
33 
34         if let Some(echo_status) = req.response_status {
35             let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
36             return Err(status);
37         }
38 
39         let res_size = if req.response_size >= 0 {
40             req.response_size as usize
41         } else {
42             let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
43             return Err(status);
44         };
45 
46         let res = SimpleResponse {
47             payload: Some(Payload {
48                 body: vec![0; res_size],
49                 ..Default::default()
50             }),
51             ..Default::default()
52         };
53 
54         Ok(Response::new(res))
55     }
56 
57     async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
58         unimplemented!()
59     }
60 
61     type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>;
62 
63     async fn streaming_output_call(
64         &self,
65         req: Request<StreamingOutputCallRequest>,
66     ) -> Result<Self::StreamingOutputCallStream> {
67         let StreamingOutputCallRequest {
68             response_parameters,
69             ..
70         } = req.into_inner();
71 
72         let stream = try_stream! {
73             for param in response_parameters {
74                 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
75 
76                 let payload = crate::server_payload(param.size as usize);
77                 yield StreamingOutputCallResponse { payload: Some(payload) };
78             }
79         };
80 
81         Ok(Response::new(
82             Box::pin(stream) as Self::StreamingOutputCallStream
83         ))
84     }
85 
86     async fn streaming_input_call(
87         &self,
88         req: Streaming<StreamingInputCallRequest>,
89     ) -> Result<StreamingInputCallResponse> {
90         let mut stream = req.into_inner();
91 
92         let mut aggregated_payload_size = 0;
93         while let Some(msg) = stream.try_next().await? {
94             aggregated_payload_size += msg.payload.unwrap().body.len() as i32;
95         }
96 
97         let res = StreamingInputCallResponse {
98             aggregated_payload_size,
99         };
100 
101         Ok(Response::new(res))
102     }
103 
104     type FullDuplexCallStream = Stream<StreamingOutputCallResponse>;
105 
106     async fn full_duplex_call(
107         &self,
108         req: Streaming<StreamingOutputCallRequest>,
109     ) -> Result<Self::FullDuplexCallStream> {
110         let mut stream = req.into_inner();
111 
112         if let Some(first_msg) = stream.message().await? {
113             if let Some(echo_status) = first_msg.response_status {
114                 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
115                 return Err(status);
116             }
117 
118             let single_message = tokio_stream::once(Ok(first_msg));
119             let mut stream = single_message.chain(stream);
120 
121             let stream = try_stream! {
122                 while let Some(msg) = stream.try_next().await? {
123                     if let Some(echo_status) = msg.response_status {
124                         let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
125                         Err(status)?;
126                     }
127 
128                     for param in msg.response_parameters {
129                         tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
130 
131                         let payload = crate::server_payload(param.size as usize);
132                         yield StreamingOutputCallResponse { payload: Some(payload) };
133                     }
134                 }
135             };
136 
137             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
138         } else {
139             let stream = tokio_stream::empty();
140             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
141         }
142     }
143 
144     type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>;
145 
146     async fn half_duplex_call(
147         &self,
148         _: Streaming<StreamingOutputCallRequest>,
149     ) -> Result<Self::HalfDuplexCallStream> {
150         Err(Status::unimplemented("TODO"))
151     }
152 
153     async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
154         Err(Status::unimplemented(""))
155     }
156 }
157 
158 #[derive(Default)]
159 pub struct UnimplementedService;
160 
161 #[tonic::async_trait]
162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
163     async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
164         Err(Status::unimplemented(""))
165     }
166 }
167 
168 #[derive(Clone, Default)]
169 pub struct EchoHeadersSvc<S> {
170     inner: S,
171 }
172 
173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
174     const NAME: &'static str = S::NAME;
175 }
176 
177 impl<S> EchoHeadersSvc<S> {
178     pub fn new(inner: S) -> Self {
179         Self { inner }
180     }
181 }
182 
183 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S>
184 where
185     S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send,
186     S::Future: Send + 'static,
187 {
188     type Response = S::Response;
189     type Error = S::Error;
190     type Future = BoxFuture<Self::Response, Self::Error>;
191 
192     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
193         Ok(()).into()
194     }
195 
196     fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
197         let echo_header = req
198             .headers()
199             .get("x-grpc-test-echo-initial")
200             .map(Clone::clone);
201 
202         let echo_trailer = req
203             .headers()
204             .get("x-grpc-test-echo-trailing-bin")
205             .map(Clone::clone)
206             .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));
207 
208         let call = self.inner.call(req);
209 
210         Box::pin(async move {
211             let mut res = call.await?;
212 
213             if let Some(echo_header) = echo_header {
214                 res.headers_mut()
215                     .insert("x-grpc-test-echo-initial", echo_header);
216                 Ok(res
217                     .map(|b| MergeTrailers::new(b, echo_trailer))
218                     .map(BoxBody::new))
219             } else {
220                 Ok(res)
221             }
222         })
223     }
224 }
225 
226 pub struct MergeTrailers<B> {
227     inner: B,
228     trailer: Option<(HeaderName, HeaderValue)>,
229 }
230 
231 impl<B> MergeTrailers<B> {
232     pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self {
233         Self { inner, trailer }
234     }
235 }
236 
237 impl<B: Body + Unpin> Body for MergeTrailers<B> {
238     type Data = B::Data;
239     type Error = B::Error;
240 
241     fn poll_data(
242         mut self: Pin<&mut Self>,
243         cx: &mut Context<'_>,
244     ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> {
245         Pin::new(&mut self.inner).poll_data(cx)
246     }
247 
248     fn poll_trailers(
249         mut self: Pin<&mut Self>,
250         cx: &mut Context<'_>,
251     ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> {
252         Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| {
253             h.map(|mut headers| {
254                 if let Some((key, value)) = &self.trailer {
255                     headers.insert(key.clone(), value.clone());
256                 }
257 
258                 headers
259             })
260         })
261     }
262 }
263