xref: /tonic/examples/src/streaming/server.rs (revision 5c8a5d70)
/// Generated code for the `grpc.examples.echo` protobuf package, pulled in with
/// `tonic::include_proto!`. This file uses it for the `EchoRequest`/`EchoResponse`
/// messages and the `echo_server` service module.
pub mod pb {
    tonic::include_proto!("grpc.examples.echo");
}
4 
5 use futures::Stream;
6 use std::{error::Error, io::ErrorKind, net::ToSocketAddrs, pin::Pin, time::Duration};
7 use tokio::sync::mpsc;
8 use tokio_stream::{wrappers::ReceiverStream, StreamExt};
9 use tonic::{transport::Server, Request, Response, Status, Streaming};
10 
11 use pb::{EchoRequest, EchoResponse};
12 
13 type EchoResult<T> = Result<Response<T>, Status>;
14 type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send>>;
15 
16 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
17     let mut err: &(dyn Error + 'static) = err_status;
18 
19     loop {
20         if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
21             return Some(io_err);
22         }
23 
24         // h2::Error do not expose std::io::Error with `source()`
25         // https://github.com/hyperium/h2/pull/462
26         if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
27             if let Some(io_err) = h2_err.get_io() {
28                 return Some(io_err);
29             }
30         }
31 
32         err = match err.source() {
33             Some(err) => err,
34             None => return None,
35         };
36     }
37 }
38 
/// Implementation of the echo gRPC service; it carries no state of its own.
#[derive(Debug)]
pub struct EchoServer {}
41 
42 #[tonic::async_trait]
43 impl pb::echo_server::Echo for EchoServer {
44     async fn unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse> {
45         Err(Status::unimplemented("not implemented"))
46     }
47 
48     type ServerStreamingEchoStream = ResponseStream;
49 
50     async fn server_streaming_echo(
51         &self,
52         req: Request<EchoRequest>,
53     ) -> EchoResult<Self::ServerStreamingEchoStream> {
54         println!("EchoServer::server_streaming_echo");
55         println!("\tclient connected from: {:?}", req.remote_addr());
56 
57         // creating infinite stream with requested message
58         let repeat = std::iter::repeat(EchoResponse {
59             message: req.into_inner().message,
60         });
61         let mut stream = Box::pin(tokio_stream::iter(repeat).throttle(Duration::from_millis(200)));
62 
63         // spawn and channel are required if you want handle "disconnect" functionality
64         // the `out_stream` will not be polled after client disconnect
65         let (tx, rx) = mpsc::channel(128);
66         tokio::spawn(async move {
67             while let Some(item) = stream.next().await {
68                 match tx.send(Result::<_, Status>::Ok(item)).await {
69                     Ok(_) => {
70                         // item (server response) was queued to be send to client
71                     }
72                     Err(_item) => {
73                         // output_stream was build from rx and both are dropped
74                         break;
75                     }
76                 }
77             }
78             println!("\tclient disconnected");
79         });
80 
81         let output_stream = ReceiverStream::new(rx);
82         Ok(Response::new(
83             Box::pin(output_stream) as Self::ServerStreamingEchoStream
84         ))
85     }
86 
87     async fn client_streaming_echo(
88         &self,
89         _: Request<Streaming<EchoRequest>>,
90     ) -> EchoResult<EchoResponse> {
91         Err(Status::unimplemented("not implemented"))
92     }
93 
94     type BidirectionalStreamingEchoStream = ResponseStream;
95 
96     async fn bidirectional_streaming_echo(
97         &self,
98         req: Request<Streaming<EchoRequest>>,
99     ) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
100         println!("EchoServer::bidirectional_streaming_echo");
101 
102         let mut in_stream = req.into_inner();
103         let (tx, rx) = mpsc::channel(128);
104 
105         // this spawn here is required if you want to handle connection error.
106         // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
107         // will be drooped when connection error occurs and error will never be propagated
108         // to mapped version of `in_stream`.
109         tokio::spawn(async move {
110             while let Some(result) = in_stream.next().await {
111                 match result {
112                     Ok(v) => tx
113                         .send(Ok(EchoResponse { message: v.message }))
114                         .await
115                         .expect("working rx"),
116                     Err(err) => {
117                         if let Some(io_err) = match_for_io_error(&err) {
118                             if io_err.kind() == ErrorKind::BrokenPipe {
119                                 // here you can handle special case when client
120                                 // disconnected in unexpected way
121                                 eprintln!("\tclient disconnected: broken pipe");
122                                 break;
123                             }
124                         }
125 
126                         match tx.send(Err(err)).await {
127                             Ok(_) => (),
128                             Err(_err) => break, // response was droped
129                         }
130                     }
131                 }
132             }
133             println!("\tstream ended");
134         });
135 
136         // echo just write the same data that was received
137         let out_stream = ReceiverStream::new(rx);
138 
139         Ok(Response::new(
140             Box::pin(out_stream) as Self::BidirectionalStreamingEchoStream
141         ))
142     }
143 }
144 
145 #[tokio::main]
146 async fn main() -> Result<(), Box<dyn std::error::Error>> {
147     let server = EchoServer {};
148     Server::builder()
149         .add_service(pb::echo_server::EchoServer::new(server))
150         .serve("[::1]:50051".to_socket_addrs().unwrap().next().unwrap())
151         .await
152         .unwrap();
153 
154     Ok(())
155 }
156