xref: /tonic/interop/src/server.rs (revision 5c7a69ec)
1 use crate::pb::{self, *};
2 use async_stream::try_stream;
3 use http::header::{HeaderMap, HeaderName};
4 use http_body_util::BodyExt;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::result::Result as StdResult;
8 use std::task::{Context, Poll};
9 use std::time::Duration;
10 use tokio_stream::StreamExt;
11 use tonic::{body::Body, server::NamedService, Code, Request, Response, Status};
12 use tower::Service;
13 
14 pub use pb::test_service_server::TestServiceServer;
15 pub use pb::unimplemented_service_server::UnimplementedServiceServer;
16 
/// Field-less handler type that implements the generated `TestService` server trait.
#[derive(Clone, Default)]
pub struct TestService {}
19 
20 type Result<T> = StdResult<Response<T>, Status>;
21 type Streaming<T> = Request<tonic::Streaming<T>>;
22 type Stream<T> = Pin<Box<dyn tokio_stream::Stream<Item = StdResult<T, Status>> + Send + 'static>>;
23 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = StdResult<T, E>> + Send + 'static>>;
24 
25 #[tonic::async_trait]
26 impl pb::test_service_server::TestService for TestService {
empty_call(&self, _request: Request<Empty>) -> Result<Empty>27     async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
28         Ok(Response::new(Empty {}))
29     }
30 
unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse>31     async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
32         let req = request.into_inner();
33 
34         if let Some(echo_status) = req.response_status {
35             let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
36             return Err(status);
37         }
38 
39         let res_size = if req.response_size >= 0 {
40             req.response_size as usize
41         } else {
42             let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
43             return Err(status);
44         };
45 
46         let res = SimpleResponse {
47             payload: Some(Payload {
48                 body: vec![0; res_size],
49                 ..Default::default()
50             }),
51             ..Default::default()
52         };
53 
54         Ok(Response::new(res))
55     }
56 
cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse>57     async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
58         unimplemented!()
59     }
60 
61     type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>;
62 
streaming_output_call( &self, req: Request<StreamingOutputCallRequest>, ) -> Result<Self::StreamingOutputCallStream>63     async fn streaming_output_call(
64         &self,
65         req: Request<StreamingOutputCallRequest>,
66     ) -> Result<Self::StreamingOutputCallStream> {
67         let StreamingOutputCallRequest {
68             response_parameters,
69             ..
70         } = req.into_inner();
71 
72         let stream = try_stream! {
73             for param in response_parameters {
74                 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
75 
76                 let payload = crate::server_payload(param.size as usize);
77                 yield StreamingOutputCallResponse { payload: Some(payload) };
78             }
79         };
80 
81         Ok(Response::new(
82             Box::pin(stream) as Self::StreamingOutputCallStream
83         ))
84     }
85 
streaming_input_call( &self, req: Streaming<StreamingInputCallRequest>, ) -> Result<StreamingInputCallResponse>86     async fn streaming_input_call(
87         &self,
88         req: Streaming<StreamingInputCallRequest>,
89     ) -> Result<StreamingInputCallResponse> {
90         let mut stream = req.into_inner();
91 
92         let mut aggregated_payload_size = 0;
93         while let Some(msg) = stream.try_next().await? {
94             aggregated_payload_size += msg.payload.unwrap().body.len() as i32;
95         }
96 
97         let res = StreamingInputCallResponse {
98             aggregated_payload_size,
99         };
100 
101         Ok(Response::new(res))
102     }
103 
104     type FullDuplexCallStream = Stream<StreamingOutputCallResponse>;
105 
full_duplex_call( &self, req: Streaming<StreamingOutputCallRequest>, ) -> Result<Self::FullDuplexCallStream>106     async fn full_duplex_call(
107         &self,
108         req: Streaming<StreamingOutputCallRequest>,
109     ) -> Result<Self::FullDuplexCallStream> {
110         let mut stream = req.into_inner();
111 
112         if let Some(first_msg) = stream.message().await? {
113             if let Some(echo_status) = first_msg.response_status {
114                 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
115                 return Err(status);
116             }
117 
118             let single_message = tokio_stream::once(Ok(first_msg));
119             let mut stream = single_message.chain(stream);
120 
121             let stream = try_stream! {
122                 while let Some(msg) = stream.try_next().await? {
123                     if let Some(echo_status) = msg.response_status {
124                         let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
125                         Err(status)?;
126                     }
127 
128                     for param in msg.response_parameters {
129                         tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await;
130 
131                         let payload = crate::server_payload(param.size as usize);
132                         yield StreamingOutputCallResponse { payload: Some(payload) };
133                     }
134                 }
135             };
136 
137             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
138         } else {
139             let stream = tokio_stream::empty();
140             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
141         }
142     }
143 
144     type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>;
145 
half_duplex_call( &self, _: Streaming<StreamingOutputCallRequest>, ) -> Result<Self::HalfDuplexCallStream>146     async fn half_duplex_call(
147         &self,
148         _: Streaming<StreamingOutputCallRequest>,
149     ) -> Result<Self::HalfDuplexCallStream> {
150         Err(Status::unimplemented("TODO"))
151     }
152 
unimplemented_call(&self, _: Request<Empty>) -> Result<Empty>153     async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
154         Err(Status::unimplemented(""))
155     }
156 }
157 
/// Field-less handler type that implements the generated `UnimplementedService`
/// server trait.
#[derive(Default)]
pub struct UnimplementedService {}
160 
161 #[tonic::async_trait]
162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty>163     async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
164         Err(Status::unimplemented(""))
165     }
166 }
167 
/// Wrapper around an inner service `S`; requests are forwarded to `inner`.
#[derive(Default, Clone)]
pub struct EchoHeadersSvc<S> {
    /// The wrapped service.
    inner: S,
}
172 
173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
174     const NAME: &'static str = S::NAME;
175 }
176 
177 impl<S> EchoHeadersSvc<S> {
new(inner: S) -> Self178     pub fn new(inner: S) -> Self {
179         Self { inner }
180     }
181 }
182 
183 impl<S> Service<http::Request<Body>> for EchoHeadersSvc<S>
184 where
185     S: Service<http::Request<Body>, Response = http::Response<Body>> + Send,
186     S::Future: Send + 'static,
187 {
188     type Response = S::Response;
189     type Error = S::Error;
190     type Future = BoxFuture<Self::Response, Self::Error>;
191 
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>>192     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
193         Ok(()).into()
194     }
195 
call(&mut self, req: http::Request<Body>) -> Self::Future196     fn call(&mut self, req: http::Request<Body>) -> Self::Future {
197         let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned();
198 
199         let trailer_name = HeaderName::from_static("x-grpc-test-echo-trailing-bin");
200         let echo_trailer = req
201             .headers()
202             .get(&trailer_name)
203             .cloned()
204             .map(|v| HeaderMap::from_iter(std::iter::once((trailer_name, v))));
205 
206         let call = self.inner.call(req);
207 
208         Box::pin(async move {
209             let mut res = call.await?;
210 
211             if let Some(echo_header) = echo_header {
212                 res.headers_mut()
213                     .insert("x-grpc-test-echo-initial", echo_header);
214                 Ok(res
215                     .map(|b| b.with_trailers(async move { echo_trailer.map(Ok) }))
216                     .map(Body::new))
217             } else {
218                 Ok(res)
219             }
220         })
221     }
222 }
223