1 use std::{fmt, sync::Arc};
2 
3 use tokio::sync::mpsc;
4 use tokio_stream::{Stream, StreamExt};
5 use tonic::{Request, Response, Status, Streaming};
6 
7 use super::ReflectionServiceState;
8 use crate::pb::v1alpha::server_reflection_request::MessageRequest;
9 use crate::pb::v1alpha::server_reflection_response::MessageResponse;
10 pub use crate::pb::v1alpha::server_reflection_server::{ServerReflection, ServerReflectionServer};
11 use crate::pb::v1alpha::{
12     ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest,
13     ServerReflectionResponse, ServiceResponse,
14 };
15 
16 /// An implementation for `ServerReflection`.
17 #[derive(Debug)]
18 pub struct ReflectionService {
19     state: Arc<ReflectionServiceState>,
20 }
21 
22 #[tonic::async_trait]
23 impl ServerReflection for ReflectionService {
24     type ServerReflectionInfoStream = ServerReflectionInfoStream;
25 
server_reflection_info( &self, req: Request<Streaming<ServerReflectionRequest>>, ) -> Result<Response<Self::ServerReflectionInfoStream>, Status>26     async fn server_reflection_info(
27         &self,
28         req: Request<Streaming<ServerReflectionRequest>>,
29     ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
30         let mut req_rx = req.into_inner();
31         let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
32 
33         let state = self.state.clone();
34 
35         tokio::spawn(async move {
36             while let Some(req) = req_rx.next().await {
37                 let Ok(req) = req else {
38                     return;
39                 };
40 
41                 let resp_msg = match req.message_request.clone() {
42                     None => Err(Status::invalid_argument("invalid MessageRequest")),
43                     Some(msg) => match msg {
44                         MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| {
45                             MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
46                                 file_descriptor_proto: vec![fd],
47                             })
48                         }),
49                         MessageRequest::FileContainingSymbol(s) => {
50                             state.symbol_by_name(&s).map(|fd| {
51                                 MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
52                                     file_descriptor_proto: vec![fd],
53                                 })
54                             })
55                         }
56                         MessageRequest::FileContainingExtension(_) => {
57                             Err(Status::not_found("extensions are not supported"))
58                         }
59                         MessageRequest::AllExtensionNumbersOfType(_) => {
60                             // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail.
61                             // https://github.com/hyperium/tonic/issues/1077
62                             Ok(MessageResponse::AllExtensionNumbersResponse(
63                                 ExtensionNumberResponse::default(),
64                             ))
65                         }
66                         MessageRequest::ListServices(_) => {
67                             Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
68                                 service: state
69                                     .list_services()
70                                     .iter()
71                                     .map(|s| ServiceResponse { name: s.clone() })
72                                     .collect(),
73                             }))
74                         }
75                     },
76                 };
77 
78                 match resp_msg {
79                     Ok(resp_msg) => {
80                         let resp = ServerReflectionResponse {
81                             valid_host: req.host.clone(),
82                             original_request: Some(req.clone()),
83                             message_response: Some(resp_msg),
84                         };
85                         resp_tx.send(Ok(resp)).await.expect("send");
86                     }
87                     Err(status) => {
88                         resp_tx.send(Err(status)).await.expect("send");
89                         return;
90                     }
91                 }
92             }
93         });
94 
95         Ok(Response::new(ServerReflectionInfoStream::new(resp_rx)))
96     }
97 }
98 
99 impl From<ReflectionServiceState> for ReflectionService {
from(state: ReflectionServiceState) -> Self100     fn from(state: ReflectionServiceState) -> Self {
101         Self {
102             state: Arc::new(state),
103         }
104     }
105 }
106 
107 /// A response stream.
108 pub struct ServerReflectionInfoStream {
109     inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>,
110 }
111 
112 impl ServerReflectionInfoStream {
new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self113     fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self {
114         let inner = tokio_stream::wrappers::ReceiverStream::new(resp_rx);
115         Self { inner }
116     }
117 }
118 
119 impl Stream for ServerReflectionInfoStream {
120     type Item = Result<ServerReflectionResponse, Status>;
121 
poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Option<Self::Item>>122     fn poll_next(
123         mut self: std::pin::Pin<&mut Self>,
124         cx: &mut std::task::Context<'_>,
125     ) -> std::task::Poll<Option<Self::Item>> {
126         std::pin::Pin::new(&mut self.inner).poll_next(cx)
127     }
128 
size_hint(&self) -> (usize, Option<usize>)129     fn size_hint(&self) -> (usize, Option<usize>) {
130         self.inner.size_hint()
131     }
132 }
133 
134 impl fmt::Debug for ServerReflectionInfoStream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result135     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136         f.debug_tuple("ServerReflectionInfoStream").finish()
137     }
138 }
139