1 use crate::pb::{self, *}; 2 use async_stream::try_stream; 3 use http::header::{HeaderMap, HeaderName, HeaderValue}; 4 use http_body::Body; 5 use std::future::Future; 6 use std::pin::Pin; 7 use std::task::{Context, Poll}; 8 use std::time::Duration; 9 use tokio_stream::StreamExt; 10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status}; 11 use tower::Service; 12 13 pub use pb::test_service_server::TestServiceServer; 14 pub use pb::unimplemented_service_server::UnimplementedServiceServer; 15 16 #[derive(Default, Clone)] 17 pub struct TestService; 18 19 type Result<T> = std::result::Result<Response<T>, Status>; 20 type Streaming<T> = Request<tonic::Streaming<T>>; 21 type Stream<T> = 22 Pin<Box<dyn tokio_stream::Stream<Item = std::result::Result<T, Status>> + Send + 'static>>; 23 type BoxFuture<T, E> = Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send + 'static>>; 24 25 #[tonic::async_trait] 26 impl pb::test_service_server::TestService for TestService { 27 async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> { 28 Ok(Response::new(Empty {})) 29 } 30 31 async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> { 32 let req = request.into_inner(); 33 34 if let Some(echo_status) = req.response_status { 35 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 36 return Err(status); 37 } 38 39 let res_size = if req.response_size >= 0 { 40 req.response_size as usize 41 } else { 42 let status = Status::new(Code::InvalidArgument, "response_size cannot be negative"); 43 return Err(status); 44 }; 45 46 let res = SimpleResponse { 47 payload: Some(Payload { 48 body: vec![0; res_size], 49 ..Default::default() 50 }), 51 ..Default::default() 52 }; 53 54 Ok(Response::new(res)) 55 } 56 57 async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> { 58 unimplemented!() 59 } 60 61 type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>; 62 63 async fn streaming_output_call( 64 &self, 65 req: Request<StreamingOutputCallRequest>, 66 ) -> Result<Self::StreamingOutputCallStream> { 67 let StreamingOutputCallRequest { 68 response_parameters, 69 .. 70 } = req.into_inner(); 71 72 let stream = try_stream! { 73 for param in response_parameters { 74 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 75 76 let payload = crate::server_payload(param.size as usize); 77 yield StreamingOutputCallResponse { payload: Some(payload) }; 78 } 79 }; 80 81 Ok(Response::new( 82 Box::pin(stream) as Self::StreamingOutputCallStream 83 )) 84 } 85 86 async fn streaming_input_call( 87 &self, 88 req: Streaming<StreamingInputCallRequest>, 89 ) -> Result<StreamingInputCallResponse> { 90 let mut stream = req.into_inner(); 91 92 let mut aggregated_payload_size = 0; 93 while let Some(msg) = stream.try_next().await? { 94 aggregated_payload_size += msg.payload.unwrap().body.len() as i32; 95 } 96 97 let res = StreamingInputCallResponse { 98 aggregated_payload_size, 99 }; 100 101 Ok(Response::new(res)) 102 } 103 104 type FullDuplexCallStream = Stream<StreamingOutputCallResponse>; 105 106 async fn full_duplex_call( 107 &self, 108 req: Streaming<StreamingOutputCallRequest>, 109 ) -> Result<Self::FullDuplexCallStream> { 110 let mut stream = req.into_inner(); 111 112 if let Some(first_msg) = stream.message().await? { 113 if let Some(echo_status) = first_msg.response_status { 114 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 115 return Err(status); 116 } 117 118 let single_message = tokio_stream::once(Ok(first_msg)); 119 let mut stream = single_message.chain(stream); 120 121 let stream = try_stream! { 122 while let Some(msg) = stream.try_next().await? { 123 if let Some(echo_status) = msg.response_status { 124 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 125 Err(status)?; 126 } 127 128 for param in msg.response_parameters { 129 tokio::time::sleep(Duration::from_micros(param.interval_us as u64)).await; 130 131 let payload = crate::server_payload(param.size as usize); 132 yield StreamingOutputCallResponse { payload: Some(payload) }; 133 } 134 } 135 }; 136 137 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 138 } else { 139 let stream = tokio_stream::empty(); 140 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 141 } 142 } 143 144 type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>; 145 146 async fn half_duplex_call( 147 &self, 148 _: Streaming<StreamingOutputCallRequest>, 149 ) -> Result<Self::HalfDuplexCallStream> { 150 Err(Status::unimplemented("TODO")) 151 } 152 153 async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> { 154 Err(Status::unimplemented("")) 155 } 156 } 157 158 #[derive(Default)] 159 pub struct UnimplementedService; 160 161 #[tonic::async_trait] 162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService { 163 async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> { 164 Err(Status::unimplemented("")) 165 } 166 } 167 168 #[derive(Clone, Default)] 169 pub struct EchoHeadersSvc<S> { 170 inner: S, 171 } 172 173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> { 174 const NAME: &'static str = S::NAME; 175 } 176 177 impl<S> EchoHeadersSvc<S> { 178 pub fn new(inner: S) -> Self { 179 Self { inner } 180 } 181 } 182 183 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S> 184 where 185 S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send, 186 S::Future: Send + 'static, 187 { 188 type Response = S::Response; 189 type Error = S::Error; 190 type Future = BoxFuture<Self::Response, Self::Error>; 191 192 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> { 193 Ok(()).into() 194 } 195 196 fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future { 197 let echo_header = req 198 .headers() 199 .get("x-grpc-test-echo-initial") 200 .map(Clone::clone); 201 202 let echo_trailer = req 203 .headers() 204 .get("x-grpc-test-echo-trailing-bin") 205 .map(Clone::clone) 206 .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v)); 207 208 let call = self.inner.call(req); 209 210 Box::pin(async move { 211 let mut res = call.await?; 212 213 if let Some(echo_header) = echo_header { 214 res.headers_mut() 215 .insert("x-grpc-test-echo-initial", echo_header); 216 Ok(res 217 .map(|b| MergeTrailers::new(b, echo_trailer)) 218 .map(BoxBody::new)) 219 } else { 220 Ok(res) 221 } 222 }) 223 } 224 } 225 226 pub struct MergeTrailers<B> { 227 inner: B, 228 trailer: Option<(HeaderName, HeaderValue)>, 229 } 230 231 impl<B> MergeTrailers<B> { 232 pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self { 233 Self { inner, trailer } 234 } 235 } 236 237 impl<B: Body + Unpin> Body for MergeTrailers<B> { 238 type Data = B::Data; 239 type Error = B::Error; 240 241 fn poll_data( 242 mut self: Pin<&mut Self>, 243 cx: &mut Context<'_>, 244 ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> { 245 Pin::new(&mut self.inner).poll_data(cx) 246 } 247 248 fn poll_trailers( 249 mut self: Pin<&mut Self>, 250 cx: &mut Context<'_>, 251 ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> { 252 Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { 253 h.map(|mut headers| { 254 if let Some((key, value)) = &self.trailer { 255 headers.insert(key.clone(), value.clone()); 256 } 257 258 headers 259 }) 260 }) 261 } 262 } 263