xref: /tonic/tonic-build/src/server.rs (revision 60b131d2)
1 use std::collections::HashSet;
2 
3 use super::{Attributes, Method, Service};
4 use crate::{
5     format_method_name, format_method_path, format_service_name, generate_doc_comment,
6     generate_doc_comments, naive_snake_case,
7 };
8 use proc_macro2::{Span, TokenStream};
9 use quote::quote;
10 use syn::{Ident, Lit, LitStr};
11 
12 #[allow(clippy::too_many_arguments)]
generate_internal<T: Service>( service: &T, emit_package: bool, proto_path: &str, compile_well_known_types: bool, attributes: &Attributes, disable_comments: &HashSet<String>, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream13 pub(crate) fn generate_internal<T: Service>(
14     service: &T,
15     emit_package: bool,
16     proto_path: &str,
17     compile_well_known_types: bool,
18     attributes: &Attributes,
19     disable_comments: &HashSet<String>,
20     use_arc_self: bool,
21     generate_default_stubs: bool,
22 ) -> TokenStream {
23     let methods = generate_methods(
24         service,
25         emit_package,
26         proto_path,
27         compile_well_known_types,
28         use_arc_self,
29         generate_default_stubs,
30     );
31 
32     let server_service = quote::format_ident!("{}Server", service.name());
33     let server_trait = quote::format_ident!("{}", service.name());
34     let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
35     let generated_trait = generate_trait(
36         service,
37         emit_package,
38         proto_path,
39         compile_well_known_types,
40         server_trait.clone(),
41         disable_comments,
42         use_arc_self,
43         generate_default_stubs,
44     );
45     let package = if emit_package { service.package() } else { "" };
46     // Transport based implementations
47     let service_name = format_service_name(service, emit_package);
48 
49     let service_doc = if disable_comments.contains(&service_name) {
50         TokenStream::new()
51     } else {
52         generate_doc_comments(service.comment())
53     };
54 
55     let named = generate_named(&server_service, &service_name);
56     let mod_attributes = attributes.for_mod(package);
57     let struct_attributes = attributes.for_struct(&service_name);
58 
59     let configure_compression_methods = quote! {
60         /// Enable decompressing requests with the given encoding.
61         #[must_use]
62         pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
63             self.accept_compression_encodings.enable(encoding);
64             self
65         }
66 
67         /// Compress responses with the given encoding, if the client supports it.
68         #[must_use]
69         pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
70             self.send_compression_encodings.enable(encoding);
71             self
72         }
73     };
74 
75     let configure_max_message_size_methods = quote! {
76         /// Limits the maximum size of a decoded message.
77         ///
78         /// Default: `4MB`
79         #[must_use]
80         pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
81             self.max_decoding_message_size = Some(limit);
82             self
83         }
84 
85         /// Limits the maximum size of an encoded message.
86         ///
87         /// Default: `usize::MAX`
88         #[must_use]
89         pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
90             self.max_encoding_message_size = Some(limit);
91             self
92         }
93     };
94 
95     quote! {
96         /// Generated server implementations.
97         #(#mod_attributes)*
98         pub mod #server_mod {
99             #![allow(
100                 unused_variables,
101                 dead_code,
102                 missing_docs,
103                 clippy::wildcard_imports,
104                 // will trigger if compression is disabled
105                 clippy::let_unit_value,
106             )]
107             use tonic::codegen::*;
108 
109             #generated_trait
110 
111             #service_doc
112             #(#struct_attributes)*
113             #[derive(Debug)]
114             pub struct #server_service<T> {
115                 inner: Arc<T>,
116                 accept_compression_encodings: EnabledCompressionEncodings,
117                 send_compression_encodings: EnabledCompressionEncodings,
118                 max_decoding_message_size: Option<usize>,
119                 max_encoding_message_size: Option<usize>,
120             }
121 
122             impl<T> #server_service<T> {
123                 pub fn new(inner: T) -> Self {
124                     Self::from_arc(Arc::new(inner))
125                 }
126 
127                 pub fn from_arc(inner: Arc<T>) -> Self {
128                     Self {
129                         inner,
130                         accept_compression_encodings: Default::default(),
131                         send_compression_encodings: Default::default(),
132                         max_decoding_message_size: None,
133                         max_encoding_message_size: None,
134                     }
135                 }
136 
137                 pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
138                 where
139                     F: tonic::service::Interceptor,
140                 {
141                     InterceptedService::new(Self::new(inner), interceptor)
142                 }
143 
144                 #configure_compression_methods
145 
146                 #configure_max_message_size_methods
147             }
148 
149             impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
150                 where
151                     T: #server_trait,
152                     B: Body + std::marker::Send + 'static,
153                     B::Error: Into<StdError> + std::marker::Send + 'static,
154             {
155                 type Response = http::Response<tonic::body::Body>;
156                 type Error = std::convert::Infallible;
157                 type Future = BoxFuture<Self::Response, Self::Error>;
158 
159                 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
160                     Poll::Ready(Ok(()))
161                 }
162 
163                 fn call(&mut self, req: http::Request<B>) -> Self::Future {
164                     match req.uri().path() {
165                         #methods
166 
167                         _ => Box::pin(async move {
168                             let mut response = http::Response::new(tonic::body::Body::default());
169                             let headers = response.headers_mut();
170                             headers.insert(tonic::Status::GRPC_STATUS, (tonic::Code::Unimplemented as i32).into());
171                             headers.insert(http::header::CONTENT_TYPE, tonic::metadata::GRPC_CONTENT_TYPE);
172                             Ok(response)
173                         }),
174                     }
175                 }
176             }
177 
178             impl<T> Clone for #server_service<T> {
179                 fn clone(&self) -> Self {
180                     let inner = self.inner.clone();
181                     Self {
182                         inner,
183                         accept_compression_encodings: self.accept_compression_encodings,
184                         send_compression_encodings: self.send_compression_encodings,
185                         max_decoding_message_size: self.max_decoding_message_size,
186                         max_encoding_message_size: self.max_encoding_message_size,
187                     }
188                 }
189             }
190 
191             #named
192         }
193     }
194 }
195 
196 #[allow(clippy::too_many_arguments)]
generate_trait<T: Service>( service: &T, emit_package: bool, proto_path: &str, compile_well_known_types: bool, server_trait: Ident, disable_comments: &HashSet<String>, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream197 fn generate_trait<T: Service>(
198     service: &T,
199     emit_package: bool,
200     proto_path: &str,
201     compile_well_known_types: bool,
202     server_trait: Ident,
203     disable_comments: &HashSet<String>,
204     use_arc_self: bool,
205     generate_default_stubs: bool,
206 ) -> TokenStream {
207     let methods = generate_trait_methods(
208         service,
209         emit_package,
210         proto_path,
211         compile_well_known_types,
212         disable_comments,
213         use_arc_self,
214         generate_default_stubs,
215     );
216     let trait_doc = generate_doc_comment(format!(
217         " Generated trait containing gRPC methods that should be implemented for use with {}Server.",
218         service.name()
219     ));
220 
221     quote! {
222         #trait_doc
223         #[async_trait]
224         pub trait #server_trait : std::marker::Send + std::marker::Sync + 'static {
225             #methods
226         }
227     }
228 }
229 
generate_trait_methods<T: Service>( service: &T, emit_package: bool, proto_path: &str, compile_well_known_types: bool, disable_comments: &HashSet<String>, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream230 fn generate_trait_methods<T: Service>(
231     service: &T,
232     emit_package: bool,
233     proto_path: &str,
234     compile_well_known_types: bool,
235     disable_comments: &HashSet<String>,
236     use_arc_self: bool,
237     generate_default_stubs: bool,
238 ) -> TokenStream {
239     let mut stream = TokenStream::new();
240 
241     for method in service.methods() {
242         let name = quote::format_ident!("{}", method.name());
243 
244         let (req_message, res_message) =
245             method.request_response_name(proto_path, compile_well_known_types);
246 
247         let method_doc =
248             if disable_comments.contains(&format_method_name(service, method, emit_package)) {
249                 TokenStream::new()
250             } else {
251                 generate_doc_comments(method.comment())
252             };
253 
254         let self_param = if use_arc_self {
255             quote!(self: std::sync::Arc<Self>)
256         } else {
257             quote!(&self)
258         };
259 
260         let method = match (
261             method.client_streaming(),
262             method.server_streaming(),
263             generate_default_stubs,
264         ) {
265             (false, false, true) => {
266                 quote! {
267                     #method_doc
268                     async fn #name(#self_param, request: tonic::Request<#req_message>)
269                         -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
270                         Err(tonic::Status::unimplemented("Not yet implemented"))
271                     }
272                 }
273             }
274             (false, false, false) => {
275                 quote! {
276                     #method_doc
277                     async fn #name(#self_param, request: tonic::Request<#req_message>)
278                         -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
279                 }
280             }
281             (true, false, true) => {
282                 quote! {
283                     #method_doc
284                     async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
285                         -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
286                         Err(tonic::Status::unimplemented("Not yet implemented"))
287                     }
288                 }
289             }
290             (true, false, false) => {
291                 quote! {
292                     #method_doc
293                     async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
294                         -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
295                 }
296             }
297             (false, true, true) => {
298                 quote! {
299                     #method_doc
300                     async fn #name(#self_param, request: tonic::Request<#req_message>)
301                         -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
302                         Err(tonic::Status::unimplemented("Not yet implemented"))
303                     }
304                 }
305             }
306             (false, true, false) => {
307                 let stream = quote::format_ident!("{}Stream", method.identifier());
308                 let stream_doc = generate_doc_comment(format!(
309                     " Server streaming response type for the {} method.",
310                     method.identifier()
311                 ));
312 
313                 quote! {
314                     #stream_doc
315                     type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
316 
317                     #method_doc
318                     async fn #name(#self_param, request: tonic::Request<#req_message>)
319                         -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
320                 }
321             }
322             (true, true, true) => {
323                 quote! {
324                     #method_doc
325                     async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
326                         -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
327                         Err(tonic::Status::unimplemented("Not yet implemented"))
328                     }
329                 }
330             }
331             (true, true, false) => {
332                 let stream = quote::format_ident!("{}Stream", method.identifier());
333                 let stream_doc = generate_doc_comment(format!(
334                     " Server streaming response type for the {} method.",
335                     method.identifier()
336                 ));
337 
338                 quote! {
339                     #stream_doc
340                     type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
341 
342                     #method_doc
343                     async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
344                         -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
345                 }
346             }
347         };
348 
349         stream.extend(method);
350     }
351 
352     stream
353 }
354 
generate_named(server_service: &syn::Ident, service_name: &str) -> TokenStream355 fn generate_named(server_service: &syn::Ident, service_name: &str) -> TokenStream {
356     let service_name = syn::LitStr::new(service_name, proc_macro2::Span::call_site());
357     let name_doc = generate_doc_comment(" Generated gRPC service name");
358 
359     quote! {
360         #name_doc
361         pub const SERVICE_NAME: &str = #service_name;
362 
363         impl<T> tonic::server::NamedService for #server_service<T> {
364             const NAME: &'static str = SERVICE_NAME;
365         }
366     }
367 }
368 
generate_methods<T: Service>( service: &T, emit_package: bool, proto_path: &str, compile_well_known_types: bool, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream369 fn generate_methods<T: Service>(
370     service: &T,
371     emit_package: bool,
372     proto_path: &str,
373     compile_well_known_types: bool,
374     use_arc_self: bool,
375     generate_default_stubs: bool,
376 ) -> TokenStream {
377     let mut stream = TokenStream::new();
378 
379     for method in service.methods() {
380         let path = format_method_path(service, method, emit_package);
381         let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
382         let ident = quote::format_ident!("{}", method.name());
383         let server_trait = quote::format_ident!("{}", service.name());
384 
385         let method_stream = match (method.client_streaming(), method.server_streaming()) {
386             (false, false) => generate_unary(
387                 method,
388                 proto_path,
389                 compile_well_known_types,
390                 ident,
391                 server_trait,
392                 use_arc_self,
393             ),
394 
395             (false, true) => generate_server_streaming(
396                 method,
397                 proto_path,
398                 compile_well_known_types,
399                 ident.clone(),
400                 server_trait,
401                 use_arc_self,
402                 generate_default_stubs,
403             ),
404             (true, false) => generate_client_streaming(
405                 method,
406                 proto_path,
407                 compile_well_known_types,
408                 ident.clone(),
409                 server_trait,
410                 use_arc_self,
411             ),
412 
413             (true, true) => generate_streaming(
414                 method,
415                 proto_path,
416                 compile_well_known_types,
417                 ident.clone(),
418                 server_trait,
419                 use_arc_self,
420                 generate_default_stubs,
421             ),
422         };
423 
424         let method = quote! {
425             #method_path => {
426                 #method_stream
427             }
428         };
429         stream.extend(method);
430     }
431 
432     stream
433 }
434 
generate_unary<T: Method>( method: &T, proto_path: &str, compile_well_known_types: bool, method_ident: Ident, server_trait: Ident, use_arc_self: bool, ) -> TokenStream435 fn generate_unary<T: Method>(
436     method: &T,
437     proto_path: &str,
438     compile_well_known_types: bool,
439     method_ident: Ident,
440     server_trait: Ident,
441     use_arc_self: bool,
442 ) -> TokenStream {
443     let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
444 
445     let service_ident = quote::format_ident!("{}Svc", method.identifier());
446 
447     let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
448 
449     let inner_arg = if use_arc_self {
450         quote!(inner)
451     } else {
452         quote!(&inner)
453     };
454 
455     quote! {
456         #[allow(non_camel_case_types)]
457         struct #service_ident<T: #server_trait >(pub Arc<T>);
458 
459         impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
460             type Response = #response;
461             type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
462 
463             fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
464                 let inner = Arc::clone(&self.0);
465                 let fut = async move {
466                     <T as #server_trait>::#method_ident(#inner_arg, request).await
467                 };
468                 Box::pin(fut)
469             }
470         }
471 
472         let accept_compression_encodings = self.accept_compression_encodings;
473         let send_compression_encodings = self.send_compression_encodings;
474         let max_decoding_message_size = self.max_decoding_message_size;
475         let max_encoding_message_size = self.max_encoding_message_size;
476         let inner = self.inner.clone();
477         let fut = async move {
478             let method = #service_ident(inner);
479             let codec = #codec_name::default();
480 
481             let mut grpc = tonic::server::Grpc::new(codec)
482                 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
483                 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
484 
485             let res = grpc.unary(method, req).await;
486             Ok(res)
487         };
488 
489         Box::pin(fut)
490     }
491 }
492 
generate_server_streaming<T: Method>( method: &T, proto_path: &str, compile_well_known_types: bool, method_ident: Ident, server_trait: Ident, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream493 fn generate_server_streaming<T: Method>(
494     method: &T,
495     proto_path: &str,
496     compile_well_known_types: bool,
497     method_ident: Ident,
498     server_trait: Ident,
499     use_arc_self: bool,
500     generate_default_stubs: bool,
501 ) -> TokenStream {
502     let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
503 
504     let service_ident = quote::format_ident!("{}Svc", method.identifier());
505 
506     let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
507 
508     let response_stream = if !generate_default_stubs {
509         let stream = quote::format_ident!("{}Stream", method.identifier());
510         quote!(type ResponseStream = T::#stream)
511     } else {
512         quote!(type ResponseStream = BoxStream<#response>)
513     };
514 
515     let inner_arg = if use_arc_self {
516         quote!(inner)
517     } else {
518         quote!(&inner)
519     };
520 
521     quote! {
522         #[allow(non_camel_case_types)]
523         struct #service_ident<T: #server_trait >(pub Arc<T>);
524 
525         impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
526             type Response = #response;
527             #response_stream;
528             type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
529 
530             fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
531                 let inner = Arc::clone(&self.0);
532                 let fut = async move {
533                     <T as #server_trait>::#method_ident(#inner_arg, request).await
534                 };
535                 Box::pin(fut)
536             }
537         }
538 
539         let accept_compression_encodings = self.accept_compression_encodings;
540         let send_compression_encodings = self.send_compression_encodings;
541         let max_decoding_message_size = self.max_decoding_message_size;
542         let max_encoding_message_size = self.max_encoding_message_size;
543         let inner = self.inner.clone();
544         let fut = async move {
545             let method = #service_ident(inner);
546             let codec = #codec_name::default();
547 
548             let mut grpc = tonic::server::Grpc::new(codec)
549                 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
550                 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
551 
552             let res = grpc.server_streaming(method, req).await;
553             Ok(res)
554         };
555 
556         Box::pin(fut)
557     }
558 }
559 
generate_client_streaming<T: Method>( method: &T, proto_path: &str, compile_well_known_types: bool, method_ident: Ident, server_trait: Ident, use_arc_self: bool, ) -> TokenStream560 fn generate_client_streaming<T: Method>(
561     method: &T,
562     proto_path: &str,
563     compile_well_known_types: bool,
564     method_ident: Ident,
565     server_trait: Ident,
566     use_arc_self: bool,
567 ) -> TokenStream {
568     let service_ident = quote::format_ident!("{}Svc", method.identifier());
569 
570     let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
571     let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
572 
573     let inner_arg = if use_arc_self {
574         quote!(inner)
575     } else {
576         quote!(&inner)
577     };
578 
579     quote! {
580         #[allow(non_camel_case_types)]
581         struct #service_ident<T: #server_trait >(pub Arc<T>);
582 
583         impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
584         {
585             type Response = #response;
586             type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
587 
588             fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
589                 let inner = Arc::clone(&self.0);
590                 let fut = async move {
591                     <T as #server_trait>::#method_ident(#inner_arg, request).await
592                 };
593                 Box::pin(fut)
594             }
595         }
596 
597         let accept_compression_encodings = self.accept_compression_encodings;
598         let send_compression_encodings = self.send_compression_encodings;
599         let max_decoding_message_size = self.max_decoding_message_size;
600         let max_encoding_message_size = self.max_encoding_message_size;
601         let inner = self.inner.clone();
602         let fut = async move {
603             let method = #service_ident(inner);
604             let codec = #codec_name::default();
605 
606             let mut grpc = tonic::server::Grpc::new(codec)
607                 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
608                 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
609 
610             let res = grpc.client_streaming(method, req).await;
611             Ok(res)
612         };
613 
614         Box::pin(fut)
615     }
616 }
617 
generate_streaming<T: Method>( method: &T, proto_path: &str, compile_well_known_types: bool, method_ident: Ident, server_trait: Ident, use_arc_self: bool, generate_default_stubs: bool, ) -> TokenStream618 fn generate_streaming<T: Method>(
619     method: &T,
620     proto_path: &str,
621     compile_well_known_types: bool,
622     method_ident: Ident,
623     server_trait: Ident,
624     use_arc_self: bool,
625     generate_default_stubs: bool,
626 ) -> TokenStream {
627     let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
628 
629     let service_ident = quote::format_ident!("{}Svc", method.identifier());
630 
631     let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
632 
633     let response_stream = if !generate_default_stubs {
634         let stream = quote::format_ident!("{}Stream", method.identifier());
635         quote!(type ResponseStream = T::#stream)
636     } else {
637         quote!(type ResponseStream = BoxStream<#response>)
638     };
639 
640     let inner_arg = if use_arc_self {
641         quote!(inner)
642     } else {
643         quote!(&inner)
644     };
645 
646     quote! {
647         #[allow(non_camel_case_types)]
648         struct #service_ident<T: #server_trait>(pub Arc<T>);
649 
650         impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
651         {
652             type Response = #response;
653             #response_stream;
654             type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
655 
656             fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
657                 let inner = Arc::clone(&self.0);
658                 let fut = async move {
659                     <T as #server_trait>::#method_ident(#inner_arg, request).await
660                 };
661                 Box::pin(fut)
662             }
663         }
664 
665         let accept_compression_encodings = self.accept_compression_encodings;
666         let send_compression_encodings = self.send_compression_encodings;
667         let max_decoding_message_size = self.max_decoding_message_size;
668         let max_encoding_message_size = self.max_encoding_message_size;
669         let inner = self.inner.clone();
670         let fut = async move {
671             let method = #service_ident(inner);
672             let codec = #codec_name::default();
673 
674             let mut grpc = tonic::server::Grpc::new(codec)
675                 .apply_compression_config(accept_compression_encodings, send_compression_encodings)
676                 .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
677 
678             let res = grpc.streaming(method, req).await;
679             Ok(res)
680         };
681 
682         Box::pin(fut)
683     }
684 }
685