1 use crate::pb::{self, *}; 2 use async_stream::try_stream; 3 use futures_util::{stream, StreamExt, TryStreamExt}; 4 use http::header::{HeaderMap, HeaderName, HeaderValue}; 5 use http_body::Body; 6 use std::future::Future; 7 use std::pin::Pin; 8 use std::task::{Context, Poll}; 9 use std::time::Duration; 10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status}; 11 use tower::Service; 12 13 pub use pb::test_service_server::TestServiceServer; 14 pub use pb::unimplemented_service_server::UnimplementedServiceServer; 15 16 #[derive(Default, Clone)] 17 pub struct TestService; 18 19 type Result<T> = std::result::Result<Response<T>, Status>; 20 type Streaming<T> = Request<tonic::Streaming<T>>; 21 type Stream<T> = Pin< 22 Box<dyn futures_core::Stream<Item = std::result::Result<T, Status>> + Send + Sync + 'static>, 23 >; 24 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'static>>; 25 26 #[tonic::async_trait] 27 impl pb::test_service_server::TestService for TestService { 28 async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> { 29 Ok(Response::new(Empty {})) 30 } 31 32 async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> { 33 let req = request.into_inner(); 34 35 if let Some(echo_status) = req.response_status { 36 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 37 return Err(status); 38 } 39 40 let res_size = if req.response_size >= 0 { 41 req.response_size as usize 42 } else { 43 let status = Status::new(Code::InvalidArgument, "response_size cannot be negative"); 44 return Err(status); 45 }; 46 47 let res = SimpleResponse { 48 payload: Some(Payload { 49 body: vec![0; res_size], 50 ..Default::default() 51 }), 52 ..Default::default() 53 }; 54 55 Ok(Response::new(res)) 56 } 57 58 async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> { 59 unimplemented!() 60 } 61 62 type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>; 63 64 async fn streaming_output_call( 65 &self, 66 req: Request<StreamingOutputCallRequest>, 67 ) -> Result<Self::StreamingOutputCallStream> { 68 let StreamingOutputCallRequest { 69 response_parameters, 70 .. 71 } = req.into_inner(); 72 73 let stream = try_stream! { 74 for param in response_parameters { 75 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 76 77 let payload = crate::server_payload(param.size as usize); 78 yield StreamingOutputCallResponse { payload: Some(payload) }; 79 } 80 }; 81 82 Ok(Response::new( 83 Box::pin(stream) as Self::StreamingOutputCallStream 84 )) 85 } 86 87 async fn streaming_input_call( 88 &self, 89 req: Streaming<StreamingInputCallRequest>, 90 ) -> Result<StreamingInputCallResponse> { 91 let mut stream = req.into_inner(); 92 93 let mut aggregated_payload_size = 0; 94 while let Some(msg) = stream.try_next().await? { 95 aggregated_payload_size += msg.payload.unwrap().body.len() as i32; 96 } 97 98 let res = StreamingInputCallResponse { 99 aggregated_payload_size, 100 }; 101 102 Ok(Response::new(res)) 103 } 104 105 type FullDuplexCallStream = Stream<StreamingOutputCallResponse>; 106 107 async fn full_duplex_call( 108 &self, 109 req: Streaming<StreamingOutputCallRequest>, 110 ) -> Result<Self::FullDuplexCallStream> { 111 let mut stream = req.into_inner(); 112 113 if let Some(first_msg) = stream.message().await? { 114 if let Some(echo_status) = first_msg.response_status { 115 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 116 return Err(status); 117 } 118 119 let single_message = stream::iter(vec![Ok(first_msg)]); 120 let mut stream = single_message.chain(stream); 121 122 let stream = try_stream! { 123 while let Some(msg) = stream.try_next().await? { 124 if let Some(echo_status) = msg.response_status { 125 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 126 Err(status)?; 127 } 128 129 for param in msg.response_parameters { 130 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 131 132 let payload = crate::server_payload(param.size as usize); 133 yield StreamingOutputCallResponse { payload: Some(payload) }; 134 } 135 } 136 }; 137 138 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 139 } else { 140 let stream = stream::empty(); 141 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 142 } 143 } 144 145 type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>; 146 147 async fn half_duplex_call( 148 &self, 149 _: Streaming<StreamingOutputCallRequest>, 150 ) -> Result<Self::HalfDuplexCallStream> { 151 Err(Status::unimplemented("TODO")) 152 } 153 154 async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> { 155 Err(Status::unimplemented("")) 156 } 157 } 158 159 #[derive(Default)] 160 pub struct UnimplementedService; 161 162 #[tonic::async_trait] 163 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService { 164 async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> { 165 Err(Status::unimplemented("")) 166 } 167 } 168 169 #[derive(Clone, Default)] 170 pub struct EchoHeadersSvc<S> { 171 inner: S, 172 } 173 174 impl<S: NamedService> NamedService for EchoHeadersSvc<S> { 175 const NAME: &'static str = S::NAME; 176 } 177 178 impl<S> EchoHeadersSvc<S> { 179 pub fn new(inner: S) -> Self { 180 Self { inner } 181 } 182 } 183 184 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S> 185 where 186 S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send, 187 S::Future: Send + 'static, 188 { 189 type Response = S::Response; 190 type Error = S::Error; 191 type Future = BoxFuture<Self::Response, Self::Error>; 192 193 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> { 194 Ok(()).into() 195 } 196 197 fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future { 198 let echo_header = req 199 .headers() 200 .get("x-grpc-test-echo-initial") 201 .map(Clone::clone); 202 203 let echo_trailer = req 204 .headers() 205 .get("x-grpc-test-echo-trailing-bin") 206 .map(Clone::clone) 207 .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v)); 208 209 let call = self.inner.call(req); 210 211 Box::pin(async move { 212 let mut res = call.await?; 213 214 if let Some(echo_header) = echo_header { 215 res.headers_mut() 216 .insert("x-grpc-test-echo-initial", echo_header); 217 Ok(res 218 .map(|b| MergeTrailers::new(b, echo_trailer)) 219 .map(BoxBody::new)) 220 } else { 221 Ok(res) 222 } 223 }) 224 } 225 } 226 227 pub struct MergeTrailers<B> { 228 inner: B, 229 trailer: Option<(HeaderName, HeaderValue)>, 230 } 231 232 impl<B> MergeTrailers<B> { 233 pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self { 234 Self { inner, trailer } 235 } 236 } 237 238 impl<B: Body + Unpin> Body for MergeTrailers<B> { 239 type Data = B::Data; 240 type Error = B::Error; 241 242 fn poll_data( 243 mut self: Pin<&mut Self>, 244 cx: &mut Context<'_>, 245 ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> { 246 Pin::new(&mut self.inner).poll_data(cx) 247 } 248 249 fn poll_trailers( 250 mut self: Pin<&mut Self>, 251 cx: &mut Context<'_>, 252 ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> { 253 Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { 254 h.map(|mut headers| { 255 if let Some((key, value)) = &self.trailer { 256 headers.insert(key.clone(), value.clone()); 257 } 258 259 headers 260 }) 261 }) 262 } 263 } 264