1 use crate::pb::{self, *}; 2 use async_stream::try_stream; 3 use futures_util::{stream, StreamExt, TryStreamExt}; 4 use http::header::{HeaderMap, HeaderName, HeaderValue}; 5 use http_body::Body; 6 use std::future::Future; 7 use std::pin::Pin; 8 use std::task::{Context, Poll}; 9 use std::time::Duration; 10 use tonic::{body::BoxBody, transport::NamedService, Code, Request, Response, Status}; 11 use tower::Service; 12 13 pub use pb::test_service_server::TestServiceServer; 14 pub use pb::unimplemented_service_server::UnimplementedServiceServer; 15 16 #[derive(Default, Clone)] 17 pub struct TestService; 18 19 type Result<T> = std::result::Result<Response<T>, Status>; 20 type Streaming<T> = Request<tonic::Streaming<T>>; 21 type Stream<T> = Pin< 22 Box<dyn futures_core::Stream<Item = std::result::Result<T, Status>> + Send + Sync + 'static>, 23 >; 24 25 #[tonic::async_trait] 26 impl pb::test_service_server::TestService for TestService { 27 async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> { 28 Ok(Response::new(Empty {})) 29 } 30 31 async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> { 32 let req = request.into_inner(); 33 34 if let Some(echo_status) = req.response_status { 35 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 36 return Err(status); 37 } 38 39 let res_size = if req.response_size >= 0 { 40 req.response_size as usize 41 } else { 42 let status = Status::new(Code::InvalidArgument, "response_size cannot be negative"); 43 return Err(status); 44 }; 45 46 let res = SimpleResponse { 47 payload: Some(Payload { 48 body: vec![0; res_size], 49 ..Default::default() 50 }), 51 ..Default::default() 52 }; 53 54 Ok(Response::new(res)) 55 } 56 57 async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> { 58 unimplemented!() 59 } 60 61 type StreamingOutputCallStream = Stream<StreamingOutputCallResponse>; 62 63 async fn streaming_output_call( 64 &self, 65 req: Request<StreamingOutputCallRequest>, 66 ) -> Result<Self::StreamingOutputCallStream> { 67 let StreamingOutputCallRequest { 68 response_parameters, 69 .. 70 } = req.into_inner(); 71 72 let stream = try_stream! { 73 for param in response_parameters { 74 tokio::time::delay_for(Duration::from_micros(param.interval_us as u64)).await; 75 76 let payload = crate::server_payload(param.size as usize); 77 yield StreamingOutputCallResponse { payload: Some(payload) }; 78 } 79 }; 80 81 Ok(Response::new( 82 Box::pin(stream) as Self::StreamingOutputCallStream 83 )) 84 } 85 86 async fn streaming_input_call( 87 &self, 88 req: Streaming<StreamingInputCallRequest>, 89 ) -> Result<StreamingInputCallResponse> { 90 let mut stream = req.into_inner(); 91 92 let mut aggregated_payload_size = 0 as i32; 93 while let Some(msg) = stream.try_next().await? { 94 aggregated_payload_size += msg.payload.unwrap().body.len() as i32; 95 } 96 97 let res = StreamingInputCallResponse { 98 aggregated_payload_size, 99 }; 100 101 Ok(Response::new(res)) 102 } 103 104 type FullDuplexCallStream = Stream<StreamingOutputCallResponse>; 105 106 async fn full_duplex_call( 107 &self, 108 req: Streaming<StreamingOutputCallRequest>, 109 ) -> Result<Self::FullDuplexCallStream> { 110 let mut stream = req.into_inner(); 111 112 if let Some(first_msg) = stream.message().await? { 113 if let Some(echo_status) = first_msg.response_status { 114 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 115 return Err(status); 116 } 117 118 let single_message = stream::iter(vec![Ok(first_msg)]); 119 let mut stream = single_message.chain(stream); 120 121 let stream = try_stream! { 122 while let Some(msg) = stream.try_next().await? { 123 if let Some(echo_status) = msg.response_status { 124 let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); 125 Err(status)?; 126 } 127 128 for param in msg.response_parameters { 129 tokio::time::delay_for(Duration::from_micros(param.interval_us as u64)).await; 130 131 let payload = crate::server_payload(param.size as usize); 132 yield StreamingOutputCallResponse { payload: Some(payload) }; 133 } 134 } 135 }; 136 137 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 138 } else { 139 let stream = stream::empty(); 140 Ok(Response::new(Box::pin(stream) as Self::FullDuplexCallStream)) 141 } 142 } 143 144 type HalfDuplexCallStream = Stream<StreamingOutputCallResponse>; 145 146 async fn half_duplex_call( 147 &self, 148 _: Streaming<StreamingOutputCallRequest>, 149 ) -> Result<Self::HalfDuplexCallStream> { 150 Err(Status::unimplemented("TODO")) 151 } 152 153 async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> { 154 Err(Status::unimplemented("")) 155 } 156 } 157 158 #[derive(Default)] 159 pub struct UnimplementedService; 160 161 #[tonic::async_trait] 162 impl pb::unimplemented_service_server::UnimplementedService for UnimplementedService { 163 async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> { 164 Err(Status::unimplemented("")) 165 } 166 } 167 168 #[derive(Clone, Default)] 169 pub struct EchoHeadersSvc<S> { 170 inner: S, 171 } 172 173 impl<S: NamedService> NamedService for EchoHeadersSvc<S> { 174 const NAME: &'static str = S::NAME; 175 } 176 177 impl<S> EchoHeadersSvc<S> { 178 pub fn new(inner: S) -> Self { 179 Self { inner } 180 } 181 } 182 183 impl<S> Service<http::Request<hyper::Body>> for EchoHeadersSvc<S> 184 where 185 S: Service<http::Request<hyper::Body>, Response = http::Response<BoxBody>> + Send, 186 S::Future: Send + 'static, 187 { 188 type Response = S::Response; 189 type Error = S::Error; 190 type Future = Pin< 191 Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send + 'static>, 192 >; 193 194 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> { 195 Ok(()).into() 196 } 197 198 fn call(&mut self, req: http::Request<hyper::Body>) -> Self::Future { 199 let echo_header = req 200 .headers() 201 .get("x-grpc-test-echo-initial") 202 .map(Clone::clone); 203 204 let echo_trailer = req 205 .headers() 206 .get("x-grpc-test-echo-trailing-bin") 207 .map(Clone::clone) 208 .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v)); 209 210 let call = self.inner.call(req); 211 212 Box::pin(async move { 213 let mut res = call.await?; 214 215 if let Some(echo_header) = echo_header { 216 res.headers_mut() 217 .insert("x-grpc-test-echo-initial", echo_header); 218 } 219 220 Ok(res 221 .map(|b| MergeTrailers::new(b, echo_trailer)) 222 .map(BoxBody::new)) 223 }) 224 } 225 } 226 227 pub struct MergeTrailers<B> { 228 inner: B, 229 trailer: Option<(HeaderName, HeaderValue)>, 230 } 231 232 impl<B> MergeTrailers<B> { 233 pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self { 234 Self { inner, trailer } 235 } 236 } 237 238 impl<B: Body + Unpin> Body for MergeTrailers<B> { 239 type Data = B::Data; 240 type Error = B::Error; 241 242 fn poll_data( 243 mut self: Pin<&mut Self>, 244 cx: &mut Context<'_>, 245 ) -> Poll<Option<std::result::Result<Self::Data, Self::Error>>> { 246 Pin::new(&mut self.inner).poll_data(cx) 247 } 248 249 fn poll_trailers( 250 mut self: Pin<&mut Self>, 251 cx: &mut Context<'_>, 252 ) -> Poll<std::result::Result<Option<HeaderMap>, Self::Error>> { 253 Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { 254 h.map(|mut headers| { 255 if let Some((key, value)) = &self.trailer { 256 headers.insert(key.clone(), value.clone()); 257 } 258 259 headers 260 }) 261 }) 262 } 263 } 264