/// Code generated for the `grpc.examples.echo` protobuf package, pulled in
/// with `tonic::include_proto!`. The message types (`EchoRequest`,
/// `EchoResponse`) and the `echo_server` service module used below live here.
pub mod pb {
    tonic::include_proto!("grpc.examples.echo");
}
4
5 use std::{error::Error, io::ErrorKind, net::ToSocketAddrs, pin::Pin, time::Duration};
6 use tokio::sync::mpsc;
7 use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
8 use tonic::{transport::Server, Request, Response, Status, Streaming};
9
10 use pb::{EchoRequest, EchoResponse};
11
12 type EchoResult<T> = Result<Response<T>, Status>;
13 type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send>>;
14
match_for_io_error(err_status: &Status) -> Option<&std::io::Error>15 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
16 let mut err: &(dyn Error + 'static) = err_status;
17
18 loop {
19 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
20 return Some(io_err);
21 }
22
23 // h2::Error do not expose std::io::Error with `source()`
24 // https://github.com/hyperium/h2/pull/462
25 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
26 if let Some(io_err) = h2_err.get_io() {
27 return Some(io_err);
28 }
29 }
30
31 err = err.source()?;
32 }
33 }
34
/// Stateless implementation of the `Echo` gRPC service.
///
/// The struct has no fields. Deriving `Default`, `Clone` and `Copy` lets
/// callers construct and share it freely; `EchoServer {}` keeps working.
#[derive(Debug, Default, Clone, Copy)]
pub struct EchoServer {}
37
38 #[tonic::async_trait]
39 impl pb::echo_server::Echo for EchoServer {
unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse>40 async fn unary_echo(&self, _: Request<EchoRequest>) -> EchoResult<EchoResponse> {
41 Err(Status::unimplemented("not implemented"))
42 }
43
44 type ServerStreamingEchoStream = ResponseStream;
45
server_streaming_echo( &self, req: Request<EchoRequest>, ) -> EchoResult<Self::ServerStreamingEchoStream>46 async fn server_streaming_echo(
47 &self,
48 req: Request<EchoRequest>,
49 ) -> EchoResult<Self::ServerStreamingEchoStream> {
50 println!("EchoServer::server_streaming_echo");
51 println!("\tclient connected from: {:?}", req.remote_addr());
52
53 // creating infinite stream with requested message
54 let repeat = std::iter::repeat(EchoResponse {
55 message: req.into_inner().message,
56 });
57 let mut stream = Box::pin(tokio_stream::iter(repeat).throttle(Duration::from_millis(200)));
58
59 // spawn and channel are required if you want handle "disconnect" functionality
60 // the `out_stream` will not be polled after client disconnect
61 let (tx, rx) = mpsc::channel(128);
62 tokio::spawn(async move {
63 while let Some(item) = stream.next().await {
64 match tx.send(Result::<_, Status>::Ok(item)).await {
65 Ok(_) => {
66 // item (server response) was queued to be send to client
67 }
68 Err(_item) => {
69 // output_stream was build from rx and both are dropped
70 break;
71 }
72 }
73 }
74 println!("\tclient disconnected");
75 });
76
77 let output_stream = ReceiverStream::new(rx);
78 Ok(Response::new(
79 Box::pin(output_stream) as Self::ServerStreamingEchoStream
80 ))
81 }
82
client_streaming_echo( &self, _: Request<Streaming<EchoRequest>>, ) -> EchoResult<EchoResponse>83 async fn client_streaming_echo(
84 &self,
85 _: Request<Streaming<EchoRequest>>,
86 ) -> EchoResult<EchoResponse> {
87 Err(Status::unimplemented("not implemented"))
88 }
89
90 type BidirectionalStreamingEchoStream = ResponseStream;
91
bidirectional_streaming_echo( &self, req: Request<Streaming<EchoRequest>>, ) -> EchoResult<Self::BidirectionalStreamingEchoStream>92 async fn bidirectional_streaming_echo(
93 &self,
94 req: Request<Streaming<EchoRequest>>,
95 ) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
96 println!("EchoServer::bidirectional_streaming_echo");
97
98 let mut in_stream = req.into_inner();
99 let (tx, rx) = mpsc::channel(128);
100
101 // this spawn here is required if you want to handle connection error.
102 // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
103 // will be dropped when connection error occurs and error will never be propagated
104 // to mapped version of `in_stream`.
105 tokio::spawn(async move {
106 while let Some(result) = in_stream.next().await {
107 match result {
108 Ok(v) => tx
109 .send(Ok(EchoResponse { message: v.message }))
110 .await
111 .expect("working rx"),
112 Err(err) => {
113 if let Some(io_err) = match_for_io_error(&err) {
114 if io_err.kind() == ErrorKind::BrokenPipe {
115 // here you can handle special case when client
116 // disconnected in unexpected way
117 eprintln!("\tclient disconnected: broken pipe");
118 break;
119 }
120 }
121
122 match tx.send(Err(err)).await {
123 Ok(_) => (),
124 Err(_err) => break, // response was dropped
125 }
126 }
127 }
128 }
129 println!("\tstream ended");
130 });
131
132 // echo just write the same data that was received
133 let out_stream = ReceiverStream::new(rx);
134
135 Ok(Response::new(
136 Box::pin(out_stream) as Self::BidirectionalStreamingEchoStream
137 ))
138 }
139 }
140
141 #[tokio::main]
main() -> Result<(), Box<dyn std::error::Error>>142 async fn main() -> Result<(), Box<dyn std::error::Error>> {
143 let server = EchoServer {};
144 Server::builder()
145 .add_service(pb::echo_server::EchoServer::new(server))
146 .serve("[::1]:50051".to_socket_addrs().unwrap().next().unwrap())
147 .await
148 .unwrap();
149
150 Ok(())
151 }
152