// xref: /tonic/tonic-reflection/src/server/mod.rs (revision f4a879db)
1 use std::collections::HashMap;
2 use std::fmt::{Display, Formatter};
3 use std::sync::Arc;
4 
5 use prost::{DecodeError, Message};
6 use prost_types::{
7     DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8     FileDescriptorSet,
9 };
10 use tonic::Status;
11 
/// v1 interface for the gRPC Reflection Service server.
///
/// Provides the server type returned by [`Builder::build_v1`].
pub mod v1;
/// v1alpha interface for the gRPC Reflection Service server.
///
/// Provides the server type returned by [`Builder::build_v1alpha`].
pub mod v1alpha;
16 
17 /// A builder used to construct a gRPC Reflection Service.
18 #[derive(Debug)]
19 pub struct Builder<'b> {
20     file_descriptor_sets: Vec<FileDescriptorSet>,
21     encoded_file_descriptor_sets: Vec<&'b [u8]>,
22     include_reflection_service: bool,
23 
24     service_names: Vec<String>,
25     use_all_service_names: bool,
26 }
27 
28 impl<'b> Builder<'b> {
29     /// Create a new builder that can configure a gRPC Reflection Service.
configure() -> Self30     pub fn configure() -> Self {
31         Builder {
32             file_descriptor_sets: Vec::new(),
33             encoded_file_descriptor_sets: Vec::new(),
34             include_reflection_service: true,
35 
36             service_names: Vec::new(),
37             use_all_service_names: true,
38         }
39     }
40 
41     /// Registers an instance of `prost_types::FileDescriptorSet` with the gRPC Reflection
42     /// Service builder.
register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self43     pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
44         self.file_descriptor_sets.push(file_descriptor_set);
45         self
46     }
47 
48     /// Registers a byte slice containing an encoded `prost_types::FileDescriptorSet` with
49     /// the gRPC Reflection Service builder.
register_encoded_file_descriptor_set( mut self, encoded_file_descriptor_set: &'b [u8], ) -> Self50     pub fn register_encoded_file_descriptor_set(
51         mut self,
52         encoded_file_descriptor_set: &'b [u8],
53     ) -> Self {
54         self.encoded_file_descriptor_sets
55             .push(encoded_file_descriptor_set);
56         self
57     }
58 
59     /// Serve the gRPC Reflection Service descriptor via the Reflection Service. This is enabled
60     /// by default - set `include` to false to disable.
include_reflection_service(mut self, include: bool) -> Self61     pub fn include_reflection_service(mut self, include: bool) -> Self {
62         self.include_reflection_service = include;
63         self
64     }
65 
66     /// Advertise a fully-qualified gRPC service name.
67     ///
68     /// If not called, then all services present in the registered file descriptor sets
69     /// will be advertised.
with_service_name(mut self, name: impl Into<String>) -> Self70     pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
71         self.use_all_service_names = false;
72         self.service_names.push(name.into());
73         self
74     }
75 
76     /// Build a v1 gRPC Reflection Service to be served via Tonic.
build_v1( mut self, ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error>77     pub fn build_v1(
78         mut self,
79     ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
80         if self.include_reflection_service {
81             self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
82         }
83 
84         Ok(v1::ServerReflectionServer::new(
85             v1::ReflectionService::from(ReflectionServiceState::new(
86                 self.service_names,
87                 self.encoded_file_descriptor_sets,
88                 self.file_descriptor_sets,
89                 self.use_all_service_names,
90             )?),
91         ))
92     }
93 
94     /// Build a v1alpha gRPC Reflection Service to be served via Tonic.
build_v1alpha( mut self, ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error>95     pub fn build_v1alpha(
96         mut self,
97     ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
98         if self.include_reflection_service {
99             self =
100                 self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
101         }
102 
103         Ok(v1alpha::ServerReflectionServer::new(
104             v1alpha::ReflectionService::from(ReflectionServiceState::new(
105                 self.service_names,
106                 self.encoded_file_descriptor_sets,
107                 self.file_descriptor_sets,
108                 self.use_all_service_names,
109             )?),
110         ))
111     }
112 }
113 
114 #[derive(Debug)]
115 struct ReflectionServiceState {
116     service_names: Vec<String>,
117     files: HashMap<String, Arc<FileDescriptorProto>>,
118     symbols: HashMap<String, Arc<FileDescriptorProto>>,
119 }
120 
121 impl ReflectionServiceState {
new( service_names: Vec<String>, encoded_file_descriptor_sets: Vec<&[u8]>, mut file_descriptor_sets: Vec<FileDescriptorSet>, use_all_service_names: bool, ) -> Result<Self, Error>122     fn new(
123         service_names: Vec<String>,
124         encoded_file_descriptor_sets: Vec<&[u8]>,
125         mut file_descriptor_sets: Vec<FileDescriptorSet>,
126         use_all_service_names: bool,
127     ) -> Result<Self, Error> {
128         for encoded in encoded_file_descriptor_sets {
129             file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
130         }
131 
132         let mut state = ReflectionServiceState {
133             service_names,
134             files: HashMap::new(),
135             symbols: HashMap::new(),
136         };
137 
138         for fds in file_descriptor_sets {
139             for fd in fds.file {
140                 let name = match fd.name.clone() {
141                     None => {
142                         return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
143                     }
144                     Some(n) => n,
145                 };
146 
147                 if state.files.contains_key(&name) {
148                     continue;
149                 }
150 
151                 let fd = Arc::new(fd);
152                 state.files.insert(name, fd.clone());
153                 state.process_file(fd, use_all_service_names)?;
154             }
155         }
156 
157         Ok(state)
158     }
159 
process_file( &mut self, fd: Arc<FileDescriptorProto>, use_all_service_names: bool, ) -> Result<(), Error>160     fn process_file(
161         &mut self,
162         fd: Arc<FileDescriptorProto>,
163         use_all_service_names: bool,
164     ) -> Result<(), Error> {
165         let prefix = &fd.package.clone().unwrap_or_default();
166 
167         for msg in &fd.message_type {
168             self.process_message(fd.clone(), prefix, msg)?;
169         }
170 
171         for en in &fd.enum_type {
172             self.process_enum(fd.clone(), prefix, en)?;
173         }
174 
175         for service in &fd.service {
176             let service_name = extract_name(prefix, "service", service.name.as_ref())?;
177             if use_all_service_names {
178                 self.service_names.push(service_name.clone());
179             }
180             self.symbols.insert(service_name.clone(), fd.clone());
181 
182             for method in &service.method {
183                 let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
184                 self.symbols.insert(method_name, fd.clone());
185             }
186         }
187 
188         Ok(())
189     }
190 
process_message( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, msg: &DescriptorProto, ) -> Result<(), Error>191     fn process_message(
192         &mut self,
193         fd: Arc<FileDescriptorProto>,
194         prefix: &str,
195         msg: &DescriptorProto,
196     ) -> Result<(), Error> {
197         let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
198         self.symbols.insert(message_name.clone(), fd.clone());
199 
200         for nested in &msg.nested_type {
201             self.process_message(fd.clone(), &message_name, nested)?;
202         }
203 
204         for en in &msg.enum_type {
205             self.process_enum(fd.clone(), &message_name, en)?;
206         }
207 
208         for field in &msg.field {
209             self.process_field(fd.clone(), &message_name, field)?;
210         }
211 
212         for oneof in &msg.oneof_decl {
213             let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
214             self.symbols.insert(oneof_name, fd.clone());
215         }
216 
217         Ok(())
218     }
219 
process_enum( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, en: &EnumDescriptorProto, ) -> Result<(), Error>220     fn process_enum(
221         &mut self,
222         fd: Arc<FileDescriptorProto>,
223         prefix: &str,
224         en: &EnumDescriptorProto,
225     ) -> Result<(), Error> {
226         let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
227         self.symbols.insert(enum_name.clone(), fd.clone());
228 
229         for value in &en.value {
230             let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
231             self.symbols.insert(value_name, fd.clone());
232         }
233 
234         Ok(())
235     }
236 
process_field( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, field: &FieldDescriptorProto, ) -> Result<(), Error>237     fn process_field(
238         &mut self,
239         fd: Arc<FileDescriptorProto>,
240         prefix: &str,
241         field: &FieldDescriptorProto,
242     ) -> Result<(), Error> {
243         let field_name = extract_name(prefix, "field", field.name.as_ref())?;
244         self.symbols.insert(field_name, fd);
245         Ok(())
246     }
247 
list_services(&self) -> &[String]248     fn list_services(&self) -> &[String] {
249         &self.service_names
250     }
251 
symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status>252     fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
253         match self.symbols.get(symbol) {
254             None => Err(Status::not_found(format!("symbol '{}' not found", symbol))),
255             Some(fd) => {
256                 let mut encoded_fd = Vec::new();
257                 if fd.clone().encode(&mut encoded_fd).is_err() {
258                     return Err(Status::internal("encoding error"));
259                 };
260 
261                 Ok(encoded_fd)
262             }
263         }
264     }
265 
file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status>266     fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
267         match self.files.get(filename) {
268             None => Err(Status::not_found(format!("file '{}' not found", filename))),
269             Some(fd) => {
270                 let mut encoded_fd = Vec::new();
271                 if fd.clone().encode(&mut encoded_fd).is_err() {
272                     return Err(Status::internal("encoding error"));
273                 }
274 
275                 Ok(encoded_fd)
276             }
277         }
278     }
279 }
280 
extract_name( prefix: &str, name_type: &str, maybe_name: Option<&String>, ) -> Result<String, Error>281 fn extract_name(
282     prefix: &str,
283     name_type: &str,
284     maybe_name: Option<&String>,
285 ) -> Result<String, Error> {
286     match maybe_name {
287         None => Err(Error::InvalidFileDescriptorSet(format!(
288             "missing {} name",
289             name_type
290         ))),
291         Some(name) => {
292             if prefix.is_empty() {
293                 Ok(name.to_string())
294             } else {
295                 Ok(format!("{}.{}", prefix, name))
296             }
297         }
298     }
299 }
300 
301 /// Represents an error in the construction of a gRPC Reflection Service.
302 #[derive(Debug)]
303 pub enum Error {
304     /// An error was encountered decoding a `prost_types::FileDescriptorSet` from a buffer.
305     DecodeError(prost::DecodeError),
306     /// An invalid `prost_types::FileDescriptorProto` was encountered.
307     InvalidFileDescriptorSet(String),
308 }
309 
310 impl From<DecodeError> for Error {
from(e: DecodeError) -> Self311     fn from(e: DecodeError) -> Self {
312         Error::DecodeError(e)
313     }
314 }
315 
316 impl std::error::Error for Error {}
317 
318 impl Display for Error {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result319     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320         match self {
321             Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
322             Error::InvalidFileDescriptorSet(s) => {
323                 write!(f, "invalid FileDescriptorSet - {}", s)
324             }
325         }
326     }
327 }
328