// xref: /tonic/examples/src/h2c/server.rs (revision 970635a1)
1 use std::net::SocketAddr;
2 
3 use hyper_util::rt::{TokioExecutor, TokioIo};
4 use hyper_util::server::conn::auto::Builder;
5 use hyper_util::service::TowerToHyperService;
6 use tokio::net::TcpListener;
7 use tonic::{service::Routes, Request, Response, Status};
8 
9 use hello_world::greeter_server::{Greeter, GreeterServer};
10 use hello_world::{HelloReply, HelloRequest};
11 
12 pub mod hello_world {
13     tonic::include_proto!("helloworld");
14 }
15 
/// Stateless implementation of the `Greeter` gRPC service.
///
/// The struct carries no fields, so it is trivially `Copy`; `Debug` is derived so the
/// public type can be printed in diagnostics.
#[derive(Debug, Default, Clone, Copy)]
pub struct MyGreeter {}
18 
19 #[tonic::async_trait]
20 impl Greeter for MyGreeter {
say_hello( &self, request: Request<HelloRequest>, ) -> Result<Response<HelloReply>, Status>21     async fn say_hello(
22         &self,
23         request: Request<HelloRequest>,
24     ) -> Result<Response<HelloReply>, Status> {
25         println!("Got a request from {:?}", request.remote_addr());
26 
27         let reply = hello_world::HelloReply {
28             message: format!("Hello {}!", request.into_inner().name),
29         };
30         Ok(Response::new(reply))
31     }
32 }
33 
34 #[tokio::main]
main() -> Result<(), Box<dyn std::error::Error>>35 async fn main() -> Result<(), Box<dyn std::error::Error>> {
36     let addr: SocketAddr = "[::1]:50051".parse().unwrap();
37     let greeter = MyGreeter::default();
38 
39     println!("GreeterServer listening on {}", addr);
40 
41     let incoming = TcpListener::bind(addr).await?;
42     let svc = Routes::new(GreeterServer::new(greeter)).prepare();
43 
44     let h2c = h2c::H2c { s: svc };
45 
46     loop {
47         match incoming.accept().await {
48             Ok((io, _)) => {
49                 let router = h2c.clone();
50                 tokio::spawn(async move {
51                     let builder = Builder::new(TokioExecutor::new());
52                     let conn = builder.serve_connection_with_upgrades(
53                         TokioIo::new(io),
54                         TowerToHyperService::new(router),
55                     );
56                     let _ = conn.await;
57                 });
58             }
59             Err(e) => {
60                 eprintln!("Error accepting connection: {}", e);
61             }
62         }
63     }
64 }
65 
66 mod h2c {
67     use std::pin::Pin;
68 
69     use http::{Request, Response};
70     use hyper::body::Incoming;
71     use hyper_util::{rt::TokioExecutor, service::TowerToHyperService};
72     use tonic::body::Body;
73     use tower::{Service, ServiceExt};
74 
75     #[derive(Clone)]
76     pub struct H2c<S> {
77         pub s: S,
78     }
79 
80     type BoxError = Box<dyn std::error::Error + Send + Sync>;
81 
82     impl<S> Service<Request<Incoming>> for H2c<S>
83     where
84         S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
85         S::Future: Send,
86         S::Error: Into<BoxError> + 'static,
87     {
88         type Response = hyper::Response<Body>;
89         type Error = hyper::Error;
90         type Future =
91             Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
92 
poll_ready( &mut self, _: &mut std::task::Context<'_>, ) -> std::task::Poll<Result<(), Self::Error>>93         fn poll_ready(
94             &mut self,
95             _: &mut std::task::Context<'_>,
96         ) -> std::task::Poll<Result<(), Self::Error>> {
97             std::task::Poll::Ready(Ok(()))
98         }
99 
call(&mut self, req: hyper::Request<Incoming>) -> Self::Future100         fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
101             let mut req = req.map(Body::new);
102             let svc = self
103                 .s
104                 .clone()
105                 .map_request(|req: Request<_>| req.map(Body::new));
106             Box::pin(async move {
107                 tokio::spawn(async move {
108                     let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap();
109 
110                     hyper::server::conn::http2::Builder::new(TokioExecutor::new())
111                         .serve_connection(upgraded_io, TowerToHyperService::new(svc))
112                         .await
113                         .unwrap();
114                 });
115 
116                 let mut res = hyper::Response::new(Body::default());
117                 *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
118                 res.headers_mut().insert(
119                     hyper::header::UPGRADE,
120                     http::header::HeaderValue::from_static("h2c"),
121                 );
122 
123                 Ok(res)
124             })
125         }
126     }
127 }
128