1 use crate::pb::{self, *}; 2 use async_stream::try_stream; 3 use http::header::{HeaderMap, HeaderName}; 4 use http_body_util::BodyExt; 5 use std::future::Future; 6 use std::pin::Pin; 7 use std::result::Result as StdResult; 8 use std::task::{Context, Poll}; 9 use std::time::Duration; 10 use tokio_stream::StreamExt; 11 use tonic::{body::Body, server::NamedService, Code, Request, Response, Status}; 12 use tower::Service; 13 14 pub use pb::test_service_server::TestServiceServer; 15 pub use pb::unimplemented_service_server::UnimplementedServiceServer; 16 17 #[derive(Default, Clone)] 18 pub struct TestService {} 19 20 type Result<T> = StdResult<Response<T>, Status>; 21 type Streaming<T> = Request<tonic::Streaming<T>>; 22 type Stream<T> = Pin<Box<dyn tokio_stream::Stream<Item = StdResult<T, Status>> + Send + 'static>>; 23 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = StdResult<T, E>> + Send + 'static>>; 24 25 #[tonic::async_trait] 26 impl pb::test_service_server::TestService for TestService { empty_call(&self, _request: Request<Empty>) -> Result<Empty>27 async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> { 28 Ok(Response::new(Empty {})) 29 } 30 unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse>31 async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> { 32 let req = request.into_inner(); 33 34 if let Some(echo_status) = req.response_status { 35 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 36 return Err(status); 37 } 38 39 let res_size = if req.response_size >= 0 { 40 req.response_size as usize 41 } else { 42 let status = Status::new(Code::InvalidArgument, "response_size cannot be negative"); 43 return Err(status); 44 }; 45 46 let res = SimpleResponse { 47 payload: Some(Payload { 48 body: vec![0; res_size], 49 ..Default::default() 50 }), 51 ..Default::default() 52 }; 53 54 Ok(Response::new(res)) 55 } 56 cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse>57 async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> { 58 unimplemented!() 59 } 60 61 type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>; 62 streaming_output_call( &self, req: Request<StreamingOutputCallRequest>, ) -> Result<Self::StreamingOutputCallStream>63 async fn streaming_output_call( 64 &self, 65 req: Request<StreamingOutputCallRequest>, 66 ) -> Result<Self::StreamingOutputCallStream> { 67 let StreamingOutputCallRequest { 68 response_parameters, 69 .. 70 } = req.into_inner(); 71 72 let stream = try_stream! { 73 for param in response_parameters { 74 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 75 76 let payload = crate::server_payload(param.size as usize); 77 yield StreamingOutputCallResponse { payload: Some(payload) }; 78 } 79 }; 80 81 Ok(Response::new( 82 Box::pin(stream) as Self::StreamingOutputCallStream 83 )) 84 } 85 streaming_input_call( &self, req: Streaming<StreamingInputCallRequest>, ) -> Result<StreamingInputCallResponse>86 async fn streaming_input_call( 87 &self, 88 req: Streaming<StreamingInputCallRequest>, 89 ) -> Result<StreamingInputCallResponse> { 90 let mut stream = req.into_inner(); 91 92 let mut aggregated_payload_size = 0; 93 while let Some(msg) = stream.try_next().await? { 94 aggregated_payload_size += msg.payload.unwrap().body.len() as i32; 95 } 96 97 let res = StreamingInputCallResponse { 98 aggregated_payload_size, 99 }; 100 101 Ok(Response::new(res)) 102 } 103 104 type FullDuplexCallStream = Stream<StreamingOutputCallResponse>; 105 full_duplex_call( &self, req: Streaming<StreamingOutputCallRequest>, ) -> Result<Self::FullDuplexCallStream>106 async fn full_duplex_call( 107 &self, 108 req: Streaming<StreamingOutputCallRequest>, 109 ) -> Result<Self::FullDuplexCallStream> { 110 let mut stream = req.into_inner(); 111 112 if let Some(first_msg) = stream.message().await? { 113 if let Some(echo_status) = first_msg.response_status { 114 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 115 return Err(status); 116 } 117 118 let single_message = tokio_stream::once(Ok(first_msg)); 119 let mut stream = single_message.chain(stream); 120 121 let stream = try_stream! { 122 while let Some(msg) = stream.try_next().await? { 123 if let Some(echo_status) = msg.response_status { 124 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 125 Err(status)?; 126 } 127 128 for param in msg.response_parameters { 129 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 130 131 let payload = crate::server_payload(param.size as usize); 132 yield StreamingOutputCallResponse { payload: Some(payload) }; 133 } 134 } 135 }; 136 137 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 138 } else { 139 let stream = tokio_stream::empty(); 140 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 141 } 142 } 143 144 type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>; 145 half_duplex_call( &self, _: Streaming<StreamingOutputCallRequest>, ) -> Result<Self::HalfDuplexCallStream>146 async fn half_duplex_call( 147 &self, 148 _: Streaming<StreamingOutputCallRequest>, 149 ) -> Result<Self::HalfDuplexCallStream> { 150 Err(Status::unimplemented("TODO")) 151 } 152 unimplemented_call(&self, _: Request<Empty>) -> Result<Empty>153 async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> { 154 Err(Status::unimplemented("")) 155 } 156 } 157 158 #[derive(Default)] 159 pub struct UnimplementedService {} 160 161 #[tonic::async_trait] 162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService { unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty>163 async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> { 164 Err(Status::unimplemented("")) 165 } 166 } 167 168 #[derive(Clone, Default)] 169 pub struct EchoHeadersSvc<S> { 170 inner: S, 171 } 172 173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> { 174 const NAME: &'static str = S::NAME; 175 } 176 177 impl<S> EchoHeadersSvc<S> { new(inner: S) -> Self178 pub fn new(inner: S) -> Self { 179 Self { inner } 180 } 181 } 182 183 impl<S> Service<http::Request<Body>> for EchoHeadersSvc<S> 184 where 185 S: Service<http::Request<Body>, Response = http::Response<Body>> + Send, 186 S::Future: Send + 'static, 187 { 188 type Response = S::Response; 189 type Error = S::Error; 190 type Future = BoxFuture<Self::Response, Self::Error>; 191 poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>>192 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> { 193 Ok(()).into() 194 } 195 call(&mut self, req: http::Request<Body>) -> Self::Future196 fn call(&mut self, req: http::Request<Body>) -> Self::Future { 197 let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned(); 198 199 let trailer_name = HeaderName::from_static("x-grpc-test-echo-trailing-bin"); 200 let echo_trailer = req 201 .headers() 202 .get(&trailer_name) 203 .cloned() 204 .map(|v| HeaderMap::from_iter(std::iter::once((trailer_name, v)))); 205 206 let call = self.inner.call(req); 207 208 Box::pin(async move { 209 let mut res = call.await?; 210 211 if let Some(echo_header) = echo_header { 212 res.headers_mut() 213 .insert("x-grpc-test-echo-initial", echo_header); 214 Ok(res 215 .map(|b| b.with_trailers(async move { echo_trailer.map(Ok) })) 216 .map(Body::new)) 217 } else { 218 Ok(res) 219 } 220 }) 221 } 222 } 223