xref: /tonic/interop/src/server.rs (revision 61555ff2)
1 use crate::pb::{self, *};
2 use async_stream::try_stream;
3 use futures_util::{stream, StreamExt, TryStreamExt};
4 use http::header::{HeaderMap, HeaderName, HeaderValue};
5 use http_body::Body;
6 use std::future::Future;
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use std::time::Duration;
10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status};
11 use tower::Service;
12 
13 pub use pb::test_service_server::TestServiceServer;
14 pub use pb::unimplemented_service_server::UnimplementedServiceServer;
15 
/// Server-side implementation of the gRPC interop `TestService`.
///
/// The type is a field-less unit struct, so constructing or cloning it is free.
#[derive(Clone, Default)]
pub struct TestService;
18 
19 type Result<T> = std::result::Result<Response<T>, Status>;
20 type Streaming<T> = Request<tonic::Streaming<T>>;
21 type Stream<T> = Pin<
22     Box<dyn futures_core::Stream<Item = std::result::Result<T, Status>> + Send + Sync + 'static>,
23 >;
24 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'static>>;
25 
26 #[tonic::async_trait]
27 impl pb::test_service_server::TestService for TestService {
28     async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
29         Ok(Response::new(Empty {}))
30     }
31 
32     async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
33         let req = request.into_inner();
34 
35         if let Some(echo_status) = req.response_status {
36             let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
37             return Err(status);
38         }
39 
40         let res_size = if req.response_size >= 0 {
41             req.response_size as usize
42         } else {
43             let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
44             return Err(status);
45         };
46 
47         let res = SimpleResponse {
48             payload: Some(Payload {
49                 body: vec![0; res_size],
50                 ..Default::default()
51             }),
52             ..Default::default()
53         };
54 
55         Ok(Response::new(res))
56     }
57 
58     async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
59         unimplemented!()
60     }
61 
62     type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>;
63 
64     async fn streaming_output_call(
65         &self,
66         req: Request<StreamingOutputCallRequest>,
67     ) -> Result<Self::StreamingOutputCallStream> {
68         let StreamingOutputCallRequest {
69             response_parameters,
70             ..
71         } = req.into_inner();
72 
73         let stream = try_stream! {
74             for param in response_parameters {
75                 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
76 
77                 let payload = crate::server_payload(param.size as usize);
78                 yield StreamingOutputCallResponse { payload: Some(payload) };
79             }
80         };
81 
82         Ok(Response::new(
83             Box::pin(stream) as Self::StreamingOutputCallStream
84         ))
85     }
86 
87     async fn streaming_input_call(
88         &self,
89         req: Streaming<StreamingInputCallRequest>,
90     ) -> Result<StreamingInputCallResponse> {
91         let mut stream = req.into_inner();
92 
93         let mut aggregated_payload_size = 0;
94         while let Some(msg) = stream.try_next().await? {
95             aggregated_payload_size += msg.payload.unwrap().body.len() as i32;
96         }
97 
98         let res = StreamingInputCallResponse {
99             aggregated_payload_size,
100         };
101 
102         Ok(Response::new(res))
103     }
104 
105     type FullDuplexCallStream = Stream<StreamingOutputCallResponse>;
106 
107     async fn full_duplex_call(
108         &self,
109         req: Streaming<StreamingOutputCallRequest>,
110     ) -> Result<Self::FullDuplexCallStream> {
111         let mut stream = req.into_inner();
112 
113         if let Some(first_msg) = stream.message().await? {
114             if let Some(echo_status) = first_msg.response_status {
115                 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
116                 return Err(status);
117             }
118 
119             let single_message = stream::iter(vec![Ok(first_msg)]);
120             let mut stream = single_message.chain(stream);
121 
122             let stream = try_stream! {
123                 while let Some(msg) = stream.try_next().await? {
124                     if let Some(echo_status) = msg.response_status {
125                         let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
126                         Err(status)?;
127                     }
128 
129                     for param in msg.response_parameters {
130                         tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
131 
132                         let payload = crate::server_payload(param.size as usize);
133                         yield StreamingOutputCallResponse { payload: Some(payload) };
134                     }
135                 }
136             };
137 
138             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
139         } else {
140             let stream = stream::empty();
141             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
142         }
143     }
144 
145     type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>;
146 
147     async fn half_duplex_call(
148         &self,
149         _: Streaming<StreamingOutputCallRequest>,
150     ) -> Result<Self::HalfDuplexCallStream> {
151         Err(Status::unimplemented("TODO"))
152     }
153 
154     async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
155         Err(Status::unimplemented(""))
156     }
157 }
158 
159 #[derive(Default)]
160 pub struct UnimplementedService;
161 
162 #[tonic::async_trait]
163 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
164     async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
165         Err(Status::unimplemented(""))
166     }
167 }
168 
169 #[derive(Clone, Default)]
170 pub struct EchoHeadersSvc<S> {
171     inner: S,
172 }
173 
174 impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
175     const NAME: &'static str = S::NAME;
176 }
177 
178 impl<S> EchoHeadersSvc<S> {
179     pub fn new(inner: S) -> Self {
180         Self { inner }
181     }
182 }
183 
184 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S>
185 where
186     S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send,
187     S::Future: Send + 'static,
188 {
189     type Response = S::Response;
190     type Error = S::Error;
191     type Future = BoxFuture<Self::Response, Self::Error>;
192 
193     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
194         Ok(()).into()
195     }
196 
197     fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
198         let echo_header = req
199             .headers()
200             .get("x-grpc-test-echo-initial")
201             .map(Clone::clone);
202 
203         let echo_trailer = req
204             .headers()
205             .get("x-grpc-test-echo-trailing-bin")
206             .map(Clone::clone)
207             .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));
208 
209         let call = self.inner.call(req);
210 
211         Box::pin(async move {
212             let mut res = call.await?;
213 
214             if let Some(echo_header) = echo_header {
215                 res.headers_mut()
216                     .insert("x-grpc-test-echo-initial", echo_header);
217                 Ok(res
218                     .map(|b| MergeTrailers::new(b, echo_trailer))
219                     .map(BoxBody::new))
220             } else {
221                 Ok(res)
222             }
223         })
224     }
225 }
226 
227 pub struct MergeTrailers<B> {
228     inner: B,
229     trailer: Option<(HeaderName, HeaderValue)>,
230 }
231 
232 impl<B> MergeTrailers<B> {
233     pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self {
234         Self { inner, trailer }
235     }
236 }
237 
238 impl<B: Body + Unpin> Body for MergeTrailers<B> {
239     type Data = B::Data;
240     type Error = B::Error;
241 
242     fn poll_data(
243         mut self: Pin<&mut Self>,
244         cx: &mut Context<'_>,
245     ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> {
246         Pin::new(&mut self.inner).poll_data(cx)
247     }
248 
249     fn poll_trailers(
250         mut self: Pin<&mut Self>,
251         cx: &mut Context<'_>,
252     ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> {
253         Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| {
254             h.map(|mut headers| {
255                 if let Some((key, value)) = &self.trailer {
256                     headers.insert(key.clone(), value.clone());
257                 }
258 
259                 headers
260             })
261         })
262     }
263 }
264