xref: /tonic/examples/src/streaming/server.rs (revision dfa08932)
/// Code generated from the `grpc.examples.echo` protobuf package, pulled in
/// through `tonic::include_proto!`.
pub mod pb {
    tonic::include_proto!("grpc.examples.echo");
}
4 
5 use std::{error::Error, io::ErrorKind, net::ToSocketAddrs, pin::Pin, time::Duration};
6 use tokio::sync::mpsc;
7 use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
8 use tonic::{transport::Server, Request, Response, Status, Streaming};
9 
10 use pb::{EchoRequest, EchoResponse};
11 
/// Return type shared by the RPC handlers in this file: a `Response<T>` on
/// success, or a `Status` describing the failure.
type EchoResult<T> = Result<Response<T>, Status>;
/// Boxed, pinned, `Send` stream of `Result<EchoResponse, Status>` items; used
/// as the associated stream type of both streaming RPCs below.
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send>>;
14 
15 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
16     let mut err: &(dyn Error + 'static) = err_status;
17 
18     loop {
19         if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
20             return Some(io_err);
21         }
22 
23         // h2::Error do not expose std::io::Error with `source()`
24         // https://github.com/hyperium/h2/pull/462
25         if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
26             if let Some(io_err) = h2_err.get_io() {
27                 return Some(io_err);
28             }
29         }
30 
31         err = match err.source() {
32             Some(err) => err,
33             None => return None,
34         };
35     }
36 }
37 
/// Field-less type on which the `Echo` gRPC service is implemented below.
#[derive(Debug)]
pub struct EchoServer {}
40 
41 #[tonic::async_trait]
42 impl pb::echo_server::Echo for EchoServer {
43     async fn unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse> {
44         Err(Status::unimplemented("not implemented"))
45     }
46 
47     type ServerStreamingEchoStream = ResponseStream;
48 
49     async fn server_streaming_echo(
50         &self,
51         req: Request<EchoRequest>,
52     ) -> EchoResult<Self::ServerStreamingEchoStream> {
53         println!("EchoServer::server_streaming_echo");
54         println!("\tclient connected from: {:?}", req.remote_addr());
55 
56         // creating infinite stream with requested message
57         let repeat = std::iter::repeat(EchoResponse {
58             message: req.into_inner().message,
59         });
60         let mut stream = Box::pin(tokio_stream::iter(repeat).throttle(Duration::from_millis(200)));
61 
62         // spawn and channel are required if you want handle "disconnect" functionality
63         // the `out_stream` will not be polled after client disconnect
64         let (tx, rx) = mpsc::channel(128);
65         tokio::spawn(async move {
66             while let Some(item) = stream.next().await {
67                 match tx.send(Result::<_, Status>::Ok(item)).await {
68                     Ok(_) => {
69                         // item (server response) was queued to be send to client
70                     }
71                     Err(_item) => {
72                         // output_stream was build from rx and both are dropped
73                         break;
74                     }
75                 }
76             }
77             println!("\tclient disconnected");
78         });
79 
80         let output_stream = ReceiverStream::new(rx);
81         Ok(Response::new(
82             Box::pin(output_stream) as Self::ServerStreamingEchoStream
83         ))
84     }
85 
86     async fn client_streaming_echo(
87         &self,
88         _: Request<Streaming<EchoRequest>>,
89     ) -> EchoResult<EchoResponse> {
90         Err(Status::unimplemented("not implemented"))
91     }
92 
93     type BidirectionalStreamingEchoStream = ResponseStream;
94 
95     async fn bidirectional_streaming_echo(
96         &self,
97         req: Request<Streaming<EchoRequest>>,
98     ) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
99         println!("EchoServer::bidirectional_streaming_echo");
100 
101         let mut in_stream = req.into_inner();
102         let (tx, rx) = mpsc::channel(128);
103 
104         // this spawn here is required if you want to handle connection error.
105         // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
106         // will be dropped when connection error occurs and error will never be propagated
107         // to mapped version of `in_stream`.
108         tokio::spawn(async move {
109             while let Some(result) = in_stream.next().await {
110                 match result {
111                     Ok(v) => tx
112                         .send(Ok(EchoResponse { message: v.message }))
113                         .await
114                         .expect("working rx"),
115                     Err(err) => {
116                         if let Some(io_err) = match_for_io_error(&err) {
117                             if io_err.kind() == ErrorKind::BrokenPipe {
118                                 // here you can handle special case when client
119                                 // disconnected in unexpected way
120                                 eprintln!("\tclient disconnected: broken pipe");
121                                 break;
122                             }
123                         }
124 
125                         match tx.send(Err(err)).await {
126                             Ok(_) => (),
127                             Err(_err) => break, // response was droped
128                         }
129                     }
130                 }
131             }
132             println!("\tstream ended");
133         });
134 
135         // echo just write the same data that was received
136         let out_stream = ReceiverStream::new(rx);
137 
138         Ok(Response::new(
139             Box::pin(out_stream) as Self::BidirectionalStreamingEchoStream
140         ))
141     }
142 }
143 
144 #[tokio::main]
145 async fn main() -> Result<(), Box<dyn std::error::Error>> {
146     let server = EchoServer {};
147     Server::builder()
148         .add_service(pb::echo_server::EchoServer::new(server))
149         .serve("[::1]:50051".to_socket_addrs().unwrap().next().unwrap())
150         .await
151         .unwrap();
152 
153     Ok(())
154 }
155