1 use std::{fmt, sync::Arc}; 2 3 use tokio::sync::mpsc; 4 use tokio_stream::{Stream, StreamExt}; 5 use tonic::{Request, Response, Status, Streaming}; 6 7 use super::ReflectionServiceState; 8 use crate::pb::v1::server_reflection_request::MessageRequest; 9 use crate::pb::v1::server_reflection_response::MessageResponse; 10 pub use crate::pb::v1::server_reflection_server::{ServerReflection, ServerReflectionServer}; 11 use crate::pb::v1::{ 12 ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest, 13 ServerReflectionResponse, ServiceResponse, 14 }; 15 16 /// An implementation for `ServerReflection`. 17 #[derive(Debug)] 18 pub struct ReflectionService { 19 state: Arc<ReflectionServiceState>, 20 } 21 22 #[tonic::async_trait] 23 impl ServerReflection for ReflectionService { 24 type ServerReflectionInfoStream = ServerReflectionInfoStream; 25 server_reflection_info( &self, req: Request<Streaming<ServerReflectionRequest>>, ) -> Result<Response<Self::ServerReflectionInfoStream>, Status>26 async fn server_reflection_info( 27 &self, 28 req: Request<Streaming<ServerReflectionRequest>>, 29 ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> { 30 let mut req_rx = req.into_inner(); 31 let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1); 32 33 let state = self.state.clone(); 34 35 tokio::spawn(async move { 36 while let Some(req) = req_rx.next().await { 37 let Ok(req) = req else { 38 return; 39 }; 40 41 let resp_msg = match req.message_request.clone() { 42 None => Err(Status::invalid_argument("invalid MessageRequest")), 43 Some(msg) => match msg { 44 MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| { 45 MessageResponse::FileDescriptorResponse(FileDescriptorResponse { 46 file_descriptor_proto: vec![fd], 47 }) 48 }), 49 MessageRequest::FileContainingSymbol(s) => { 50 state.symbol_by_name(&s).map(|fd| { 51 MessageResponse::FileDescriptorResponse(FileDescriptorResponse { 52 file_descriptor_proto: vec![fd], 53 }) 54 }) 55 } 56 MessageRequest::FileContainingExtension(_) => { 57 Err(Status::not_found("extensions are not supported")) 58 } 59 MessageRequest::AllExtensionNumbersOfType(_) => { 60 // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail. 61 // https://github.com/hyperium/tonic/issues/1077 62 Ok(MessageResponse::AllExtensionNumbersResponse( 63 ExtensionNumberResponse::default(), 64 )) 65 } 66 MessageRequest::ListServices(_) => { 67 Ok(MessageResponse::ListServicesResponse(ListServiceResponse { 68 service: state 69 .list_services() 70 .iter() 71 .map(|s| ServiceResponse { name: s.clone() }) 72 .collect(), 73 })) 74 } 75 }, 76 }; 77 78 match resp_msg { 79 Ok(resp_msg) => { 80 let resp = ServerReflectionResponse { 81 valid_host: req.host.clone(), 82 original_request: Some(req.clone()), 83 message_response: Some(resp_msg), 84 }; 85 resp_tx.send(Ok(resp)).await.expect("send"); 86 } 87 Err(status) => { 88 resp_tx.send(Err(status)).await.expect("send"); 89 return; 90 } 91 } 92 } 93 }); 94 95 Ok(Response::new(ServerReflectionInfoStream::new(resp_rx))) 96 } 97 } 98 99 impl From<ReflectionServiceState> for ReflectionService { from(state: ReflectionServiceState) -> Self100 fn from(state: ReflectionServiceState) -> Self { 101 Self { 102 state: Arc::new(state), 103 } 104 } 105 } 106 107 /// A response stream. 108 pub struct ServerReflectionInfoStream { 109 inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>, 110 } 111 112 impl ServerReflectionInfoStream { new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self113 fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self { 114 let inner = tokio_stream::wrappers::ReceiverStream::new(resp_rx); 115 Self { inner } 116 } 117 } 118 119 impl Stream for ServerReflectionInfoStream { 120 type Item = Result<ServerReflectionResponse, Status>; 121 poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Option<Self::Item>>122 fn poll_next( 123 mut self: std::pin::Pin<&mut Self>, 124 cx: &mut std::task::Context<'_>, 125 ) -> std::task::Poll<Option<Self::Item>> { 126 std::pin::Pin::new(&mut self.inner).poll_next(cx) 127 } 128 size_hint(&self) -> (usize, Option<usize>)129 fn size_hint(&self) -> (usize, Option<usize>) { 130 self.inner.size_hint() 131 } 132 } 133 134 impl fmt::Debug for ServerReflectionInfoStream { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 136 f.debug_tuple("ServerReflectionInfoStream").finish() 137 } 138 } 139