xref: /tonic/examples/src/h2c/server.rs (revision ff71e893)
1 use tonic::{transport::Server, Request, Response, Status};
2 
3 use hello_world::greeter_server::{Greeter, GreeterServer};
4 use hello_world::{HelloReply, HelloRequest};
5 use tower::make::Shared;
6 
/// Rust code generated for the `helloworld` protobuf package (presumably by
/// `tonic-build` in this example's build script — confirm), providing
/// `HelloRequest`, `HelloReply` and the `greeter_server` module used above.
pub mod hello_world {
    tonic::include_proto!("helloworld");
}
10 
/// Implementation of the `Greeter` gRPC service. It carries no state.
#[derive(Debug, Default)]
pub struct MyGreeter {}
13 
14 #[tonic::async_trait]
15 impl Greeter for MyGreeter {
16     async fn say_hello(
17         &self,
18         request: Request<HelloRequest>,
19     ) -> Result<Response<HelloReply>, Status> {
20         println!("Got a request from {:?}", request.remote_addr());
21 
22         let reply = hello_world::HelloReply {
23             message: format!("Hello {}!", request.into_inner().name),
24         };
25         Ok(Response::new(reply))
26     }
27 }
28 
29 #[tokio::main]
30 async fn main() -> Result<(), Box<dyn std::error::Error>> {
31     let addr = "[::1]:50051".parse().unwrap();
32     let greeter = MyGreeter::default();
33 
34     println!("GreeterServer listening on {}", addr);
35 
36     let svc = Server::builder()
37         .add_service(GreeterServer::new(greeter))
38         .into_router();
39 
40     let h2c = h2c::H2c { s: svc };
41 
42     let server = hyper::Server::bind(&addr).serve(Shared::new(h2c));
43     server.await.unwrap();
44 
45     Ok(())
46 }
47 
48 mod h2c {
49     use std::pin::Pin;
50 
51     use http::{Request, Response};
52     use hyper::Body;
53     use tower::Service;
54 
55     #[derive(Clone)]
56     pub struct H2c<S> {
57         pub s: S,
58     }
59 
60     type BoxError = Box<dyn std::error::Error + Send + Sync>;
61 
62     impl<S> Service<Request<Body>> for H2c<S>
63     where
64         S: Service<Request<Body>, Response = Response<tonic::transport::AxumBoxBody>>
65             + Clone
66             + Send
67             + 'static,
68         S::Future: Send + 'static,
69         S::Error: Into<BoxError> + Sync + Send + 'static,
70         S::Response: Send + 'static,
71     {
72         type Response = hyper::Response<Body>;
73         type Error = hyper::Error;
74         type Future =
75             Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
76 
77         fn poll_ready(
78             &mut self,
79             _: &mut std::task::Context<'_>,
80         ) -> std::task::Poll<Result<(), Self::Error>> {
81             std::task::Poll::Ready(Ok(()))
82         }
83 
84         fn call(&mut self, mut req: hyper::Request<Body>) -> Self::Future {
85             let svc = self.s.clone();
86             Box::pin(async move {
87                 tokio::spawn(async move {
88                     let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap();
89 
90                     hyper::server::conn::Http::new()
91                         .http2_only(true)
92                         .serve_connection(upgraded_io, svc)
93                         .await
94                         .unwrap();
95                 });
96 
97                 let mut res = hyper::Response::new(hyper::Body::empty());
98                 *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
99                 res.headers_mut().insert(
100                     hyper::header::UPGRADE,
101                     http::header::HeaderValue::from_static("h2c"),
102                 );
103 
104                 Ok(res)
105             })
106         }
107     }
108 }
109