xref: /tonic/interop/src/server.rs (revision da92dbf8)
1 use crate::pb::{self, *};
2 use async_stream::try_stream;
3 use futures_util::{stream, StreamExt, TryStreamExt};
4 use http::header::{HeaderMap, HeaderName, HeaderValue};
5 use http_body::Body;
6 use std::future::Future;
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use std::time::Duration;
10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status};
11 use tower::Service;
12 
13 pub use pb::test_service_server::TestServiceServer;
14 pub use pb::unimplemented_service_server::UnimplementedServiceServer;
15 
/// Server-side implementation of the gRPC interop `TestService`.
///
/// The type holds no state: it is a zero-sized handle, and all behavior lives in its
/// `pb::test_service_server::TestService` trait impl.
#[derive(Clone, Default)]
pub struct TestService;
18 
19 type Result<T> = std::result::Result<Response<T>, Status>;
20 type Streaming<T> = Request<tonic::Streaming<T>>;
21 type Stream<T> = Pin<
22     Box<dyn futures_core::Stream<Item = std::result::Result<T, Status>> + Send + Sync + 'static>,
23 >;
24 
25 #[tonic::async_trait]
26 impl pb::test_service_server::TestService for TestService {
27     async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
28         Ok(Response::new(Empty {}))
29     }
30 
31     async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
32         let req = request.into_inner();
33 
34         if let Some(echo_status) = req.response_status {
35             let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
36             return Err(status);
37         }
38 
39         let res_size = if req.response_size >= 0 {
40             req.response_size as usize
41         } else {
42             let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
43             return Err(status);
44         };
45 
46         let res = SimpleResponse {
47             payload: Some(Payload {
48                 body: vec![0; res_size],
49                 ..Default::default()
50             }),
51             ..Default::default()
52         };
53 
54         Ok(Response::new(res))
55     }
56 
57     async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
58         unimplemented!()
59     }
60 
61     type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>;
62 
63     async fn streaming_output_call(
64         &self,
65         req: Request<StreamingOutputCallRequest>,
66     ) -> Result<Self::StreamingOutputCallStream> {
67         let StreamingOutputCallRequest {
68             response_parameters,
69             ..
70         } = req.into_inner();
71 
72         let stream = try_stream! {
73             for param in response_parameters {
74                 tokio::time::delay_for(Duration::from_micros(param.interval_us as u64)).await;
75 
76                 let payload = crate::server_payload(param.size as usize);
77                 yield StreamingOutputCallResponse { payload: Some(payload) };
78             }
79         };
80 
81         Ok(Response::new(
82             Box::pin(stream) as Self::StreamingOutputCallStream
83         ))
84     }
85 
86     async fn streaming_input_call(
87         &self,
88         req: Streaming<StreamingInputCallRequest>,
89     ) -> Result<StreamingInputCallResponse> {
90         let mut stream = req.into_inner();
91 
92         let mut aggregated_payload_size = 0 as i32;
93         while let Some(msg) = stream.try_next().await? {
94             aggregated_payload_size += msg.payload.unwrap().body.len() as i32;
95         }
96 
97         let res = StreamingInputCallResponse {
98             aggregated_payload_size,
99         };
100 
101         Ok(Response::new(res))
102     }
103 
104     type FullDuplexCallStream = Stream<StreamingOutputCallResponse>;
105 
106     async fn full_duplex_call(
107         &self,
108         req: Streaming<StreamingOutputCallRequest>,
109     ) -> Result<Self::FullDuplexCallStream> {
110         let mut stream = req.into_inner();
111 
112         if let Some(first_msg) = stream.message().await? {
113             if let Some(echo_status) = first_msg.response_status {
114                 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
115                 return Err(status);
116             }
117 
118             let single_message = stream::iter(vec![Ok(first_msg)]);
119             let mut stream = single_message.chain(stream);
120 
121             let stream = try_stream! {
122                 while let Some(msg) = stream.try_next().await? {
123                     if let Some(echo_status) = msg.response_status {
124                         let status = Status::new(Code::from_i32(echo_status.code), echo_status.message);
125                         Err(status)?;
126                     }
127 
128                     for param in msg.response_parameters {
129                         tokio::time::delay_for(Duration::from_micros(param.interval_us as u64)).await;
130 
131                         let payload = crate::server_payload(param.size as usize);
132                         yield StreamingOutputCallResponse { payload: Some(payload) };
133                     }
134                 }
135             };
136 
137             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
138         } else {
139             let stream = stream::empty();
140             Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream))
141         }
142     }
143 
144     type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>;
145 
146     async fn half_duplex_call(
147         &self,
148         _: Streaming<StreamingOutputCallRequest>,
149     ) -> Result<Self::HalfDuplexCallStream> {
150         Err(Status::unimplemented("TODO"))
151     }
152 
153     async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
154         Err(Status::unimplemented(""))
155     }
156 }
157 
/// Stateless implementation of the interop `UnimplementedService`.
///
/// Its only handler, `unimplemented_call`, answers with `Code::Unimplemented`.
#[derive(Default)]
pub struct UnimplementedService;
160 
161 #[tonic::async_trait]
162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
163     async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
164         Err(Status::unimplemented(""))
165     }
166 }
167 
/// Middleware wrapping a gRPC service so its responses echo the interop test metadata.
///
/// The `x-grpc-test-echo-initial` request header is copied onto the response headers, and
/// `x-grpc-test-echo-trailing-bin` is merged into the response trailers.
#[derive(Default, Clone)]
pub struct EchoHeadersSvc<S> {
    /// The wrapped service that actually handles each request.
    inner: S,
}
172 
173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
174     const NAME: &'static str = S::NAME;
175 }
176 
177 impl<S> EchoHeadersSvc<S> {
178     pub fn new(inner: S) -> Self {
179         Self { inner }
180     }
181 }
182 
183 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S>
184 where
185     S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send,
186     S::Future: Send + 'static,
187 {
188     type Response = S::Response;
189     type Error = S::Error;
190     type Future = Pin<
191         Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send + 'static>,
192     >;
193 
194     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
195         Ok(()).into()
196     }
197 
198     fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future {
199         let echo_header = req
200             .headers()
201             .get("x-grpc-test-echo-initial")
202             .map(Clone::clone);
203 
204         let echo_trailer = req
205             .headers()
206             .get("x-grpc-test-echo-trailing-bin")
207             .map(Clone::clone)
208             .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));
209 
210         let call = self.inner.call(req);
211 
212         Box::pin(async move {
213             let mut res = call.await?;
214 
215             if let Some(echo_header) = echo_header {
216                 res.headers_mut()
217                     .insert("x-grpc-test-echo-initial", echo_header);
218             }
219 
220             Ok(res
221                 .map(|b| MergeTrailers::new(b, echo_trailer))
222                 .map(BoxBody::new))
223         })
224     }
225 }
226 
227 pub struct MergeTrailers<B> {
228     inner: B,
229     trailer: Option<(HeaderName, HeaderValue)>,
230 }
231 
232 impl<B> MergeTrailers<B> {
233     pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self {
234         Self { inner, trailer }
235     }
236 }
237 
238 impl<B: Body + Unpin> Body for MergeTrailers<B> {
239     type Data = B::Data;
240     type Error = B::Error;
241 
242     fn poll_data(
243         mut self: Pin<&mut Self>,
244         cx: &mut Context<'_>,
245     ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> {
246         Pin::new(&mut self.inner).poll_data(cx)
247     }
248 
249     fn poll_trailers(
250         mut self: Pin<&mut Self>,
251         cx: &mut Context<'_>,
252     ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> {
253         Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| {
254             h.map(|mut headers| {
255                 if let Some((key, value)) = &self.trailer {
256                     headers.insert(key.clone(), value.clone());
257                 }
258 
259                 headers
260             })
261         })
262     }
263 }
264