1 use std::collections::HashMap;
2 use std::fmt::{Display, Formatter};
3 use std::sync::Arc;
4
5 use prost::{DecodeError, Message};
6 use prost_types::{
7 DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8 FileDescriptorSet,
9 };
10 use tonic::Status;
11
12 /// v1 interface for the gRPC Reflection Service server.
13 pub mod v1;
14 /// v1alpha interface for the gRPC Reflection Service server.
15 pub mod v1alpha;
16
17 /// A builder used to construct a gRPC Reflection Service.
18 #[derive(Debug)]
19 pub struct Builder<'b> {
20 file_descriptor_sets: Vec<FileDescriptorSet>,
21 encoded_file_descriptor_sets: Vec<&'b [u8]>,
22 include_reflection_service: bool,
23
24 service_names: Vec<String>,
25 use_all_service_names: bool,
26 }
27
28 impl<'b> Builder<'b> {
29 /// Create a new builder that can configure a gRPC Reflection Service.
configure() -> Self30 pub fn configure() -> Self {
31 Builder {
32 file_descriptor_sets: Vec::new(),
33 encoded_file_descriptor_sets: Vec::new(),
34 include_reflection_service: true,
35
36 service_names: Vec::new(),
37 use_all_service_names: true,
38 }
39 }
40
41 /// Registers an instance of `prost_types::FileDescriptorSet` with the gRPC Reflection
42 /// Service builder.
register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self43 pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
44 self.file_descriptor_sets.push(file_descriptor_set);
45 self
46 }
47
48 /// Registers a byte slice containing an encoded `prost_types::FileDescriptorSet` with
49 /// the gRPC Reflection Service builder.
register_encoded_file_descriptor_set( mut self, encoded_file_descriptor_set: &'b [u8], ) -> Self50 pub fn register_encoded_file_descriptor_set(
51 mut self,
52 encoded_file_descriptor_set: &'b [u8],
53 ) -> Self {
54 self.encoded_file_descriptor_sets
55 .push(encoded_file_descriptor_set);
56 self
57 }
58
59 /// Serve the gRPC Reflection Service descriptor via the Reflection Service. This is enabled
60 /// by default - set `include` to false to disable.
include_reflection_service(mut self, include: bool) -> Self61 pub fn include_reflection_service(mut self, include: bool) -> Self {
62 self.include_reflection_service = include;
63 self
64 }
65
66 /// Advertise a fully-qualified gRPC service name.
67 ///
68 /// If not called, then all services present in the registered file descriptor sets
69 /// will be advertised.
with_service_name(mut self, name: impl Into<String>) -> Self70 pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
71 self.use_all_service_names = false;
72 self.service_names.push(name.into());
73 self
74 }
75
76 /// Build a v1 gRPC Reflection Service to be served via Tonic.
build_v1( mut self, ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error>77 pub fn build_v1(
78 mut self,
79 ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
80 if self.include_reflection_service {
81 self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
82 }
83
84 Ok(v1::ServerReflectionServer::new(
85 v1::ReflectionService::from(ReflectionServiceState::new(
86 self.service_names,
87 self.encoded_file_descriptor_sets,
88 self.file_descriptor_sets,
89 self.use_all_service_names,
90 )?),
91 ))
92 }
93
94 /// Build a v1alpha gRPC Reflection Service to be served via Tonic.
build_v1alpha( mut self, ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error>95 pub fn build_v1alpha(
96 mut self,
97 ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
98 if self.include_reflection_service {
99 self =
100 self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
101 }
102
103 Ok(v1alpha::ServerReflectionServer::new(
104 v1alpha::ReflectionService::from(ReflectionServiceState::new(
105 self.service_names,
106 self.encoded_file_descriptor_sets,
107 self.file_descriptor_sets,
108 self.use_all_service_names,
109 )?),
110 ))
111 }
112 }
113
114 #[derive(Debug)]
115 struct ReflectionServiceState {
116 service_names: Vec<String>,
117 files: HashMap<String, Arc<FileDescriptorProto>>,
118 symbols: HashMap<String, Arc<FileDescriptorProto>>,
119 }
120
121 impl ReflectionServiceState {
new( service_names: Vec<String>, encoded_file_descriptor_sets: Vec<&[u8]>, mut file_descriptor_sets: Vec<FileDescriptorSet>, use_all_service_names: bool, ) -> Result<Self, Error>122 fn new(
123 service_names: Vec<String>,
124 encoded_file_descriptor_sets: Vec<&[u8]>,
125 mut file_descriptor_sets: Vec<FileDescriptorSet>,
126 use_all_service_names: bool,
127 ) -> Result<Self, Error> {
128 for encoded in encoded_file_descriptor_sets {
129 file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
130 }
131
132 let mut state = ReflectionServiceState {
133 service_names,
134 files: HashMap::new(),
135 symbols: HashMap::new(),
136 };
137
138 for fds in file_descriptor_sets {
139 for fd in fds.file {
140 let name = match fd.name.clone() {
141 None => {
142 return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
143 }
144 Some(n) => n,
145 };
146
147 if state.files.contains_key(&name) {
148 continue;
149 }
150
151 let fd = Arc::new(fd);
152 state.files.insert(name, fd.clone());
153 state.process_file(fd, use_all_service_names)?;
154 }
155 }
156
157 Ok(state)
158 }
159
process_file( &mut self, fd: Arc<FileDescriptorProto>, use_all_service_names: bool, ) -> Result<(), Error>160 fn process_file(
161 &mut self,
162 fd: Arc<FileDescriptorProto>,
163 use_all_service_names: bool,
164 ) -> Result<(), Error> {
165 let prefix = &fd.package.clone().unwrap_or_default();
166
167 for msg in &fd.message_type {
168 self.process_message(fd.clone(), prefix, msg)?;
169 }
170
171 for en in &fd.enum_type {
172 self.process_enum(fd.clone(), prefix, en)?;
173 }
174
175 for service in &fd.service {
176 let service_name = extract_name(prefix, "service", service.name.as_ref())?;
177 if use_all_service_names {
178 self.service_names.push(service_name.clone());
179 }
180 self.symbols.insert(service_name.clone(), fd.clone());
181
182 for method in &service.method {
183 let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
184 self.symbols.insert(method_name, fd.clone());
185 }
186 }
187
188 Ok(())
189 }
190
process_message( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, msg: &DescriptorProto, ) -> Result<(), Error>191 fn process_message(
192 &mut self,
193 fd: Arc<FileDescriptorProto>,
194 prefix: &str,
195 msg: &DescriptorProto,
196 ) -> Result<(), Error> {
197 let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
198 self.symbols.insert(message_name.clone(), fd.clone());
199
200 for nested in &msg.nested_type {
201 self.process_message(fd.clone(), &message_name, nested)?;
202 }
203
204 for en in &msg.enum_type {
205 self.process_enum(fd.clone(), &message_name, en)?;
206 }
207
208 for field in &msg.field {
209 self.process_field(fd.clone(), &message_name, field)?;
210 }
211
212 for oneof in &msg.oneof_decl {
213 let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
214 self.symbols.insert(oneof_name, fd.clone());
215 }
216
217 Ok(())
218 }
219
process_enum( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, en: &EnumDescriptorProto, ) -> Result<(), Error>220 fn process_enum(
221 &mut self,
222 fd: Arc<FileDescriptorProto>,
223 prefix: &str,
224 en: &EnumDescriptorProto,
225 ) -> Result<(), Error> {
226 let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
227 self.symbols.insert(enum_name.clone(), fd.clone());
228
229 for value in &en.value {
230 let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
231 self.symbols.insert(value_name, fd.clone());
232 }
233
234 Ok(())
235 }
236
process_field( &mut self, fd: Arc<FileDescriptorProto>, prefix: &str, field: &FieldDescriptorProto, ) -> Result<(), Error>237 fn process_field(
238 &mut self,
239 fd: Arc<FileDescriptorProto>,
240 prefix: &str,
241 field: &FieldDescriptorProto,
242 ) -> Result<(), Error> {
243 let field_name = extract_name(prefix, "field", field.name.as_ref())?;
244 self.symbols.insert(field_name, fd);
245 Ok(())
246 }
247
list_services(&self) -> &[String]248 fn list_services(&self) -> &[String] {
249 &self.service_names
250 }
251
symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status>252 fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
253 match self.symbols.get(symbol) {
254 None => Err(Status::not_found(format!("symbol '{}' not found", symbol))),
255 Some(fd) => {
256 let mut encoded_fd = Vec::new();
257 if fd.clone().encode(&mut encoded_fd).is_err() {
258 return Err(Status::internal("encoding error"));
259 };
260
261 Ok(encoded_fd)
262 }
263 }
264 }
265
file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status>266 fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
267 match self.files.get(filename) {
268 None => Err(Status::not_found(format!("file '{}' not found", filename))),
269 Some(fd) => {
270 let mut encoded_fd = Vec::new();
271 if fd.clone().encode(&mut encoded_fd).is_err() {
272 return Err(Status::internal("encoding error"));
273 }
274
275 Ok(encoded_fd)
276 }
277 }
278 }
279 }
280
extract_name( prefix: &str, name_type: &str, maybe_name: Option<&String>, ) -> Result<String, Error>281 fn extract_name(
282 prefix: &str,
283 name_type: &str,
284 maybe_name: Option<&String>,
285 ) -> Result<String, Error> {
286 match maybe_name {
287 None => Err(Error::InvalidFileDescriptorSet(format!(
288 "missing {} name",
289 name_type
290 ))),
291 Some(name) => {
292 if prefix.is_empty() {
293 Ok(name.to_string())
294 } else {
295 Ok(format!("{}.{}", prefix, name))
296 }
297 }
298 }
299 }
300
301 /// Represents an error in the construction of a gRPC Reflection Service.
302 #[derive(Debug)]
303 pub enum Error {
304 /// An error was encountered decoding a `prost_types::FileDescriptorSet` from a buffer.
305 DecodeError(prost::DecodeError),
306 /// An invalid `prost_types::FileDescriptorProto` was encountered.
307 InvalidFileDescriptorSet(String),
308 }
309
310 impl From<DecodeError> for Error {
from(e: DecodeError) -> Self311 fn from(e: DecodeError) -> Self {
312 Error::DecodeError(e)
313 }
314 }
315
316 impl std::error::Error for Error {}
317
318 impl Display for Error {
fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result319 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320 match self {
321 Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
322 Error::InvalidFileDescriptorSet(s) => {
323 write!(f, "invalid FileDescriptorSet - {}", s)
324 }
325 }
326 }
327 }
328