1 use bytes::Bytes;
2 use http::Uri;
3 use hyper_util::rt::TokioIo;
4 use integration_tests::mock::MockStream;
5 use integration_tests::pb::{
6     test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output,
7     OutputStream,
8 };
9 use std::error::Error;
10 use std::time::Duration;
11 use tokio::{net::TcpListener, sync::oneshot};
12 use tonic::metadata::{MetadataMap, MetadataValue};
13 use tonic::{
14     transport::{server::TcpIncoming, Endpoint, Server},
15     Code, Request, Response, Status,
16 };
17 
18 #[tokio::test]
status_with_details()19 async fn status_with_details() {
20     struct Svc;
21 
22     #[tonic::async_trait]
23     impl test_server::Test for Svc {
24         async fn unary_call(&self, _: Request<Input>) -> Result<Response<Output>, Status> {
25             Err(Status::with_details(
26                 Code::ResourceExhausted,
27                 "Too many requests",
28                 Bytes::from_static(&[1]),
29             ))
30         }
31     }
32 
33     let svc = test_server::TestServer::new(Svc);
34 
35     let (tx, rx) = oneshot::channel::<()>();
36 
37     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
38     let addr = listener.local_addr().unwrap();
39     let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));
40 
41     let jh = tokio::spawn(async move {
42         Server::builder()
43             .add_service(svc)
44             .serve_with_incoming_shutdown(incoming, async { drop(rx.await) })
45             .await
46             .unwrap();
47     });
48 
49     tokio::time::sleep(Duration::from_millis(100)).await;
50 
51     let mut channel = test_client::TestClient::connect(format!("http://{addr}"))
52         .await
53         .unwrap();
54 
55     let err = channel
56         .unary_call(Request::new(Input {}))
57         .await
58         .unwrap_err();
59 
60     assert_eq!(err.message(), "Too many requests");
61     assert_eq!(err.details(), &[1]);
62 
63     tx.send(()).unwrap();
64 
65     jh.await.unwrap();
66 }
67 
68 #[tokio::test]
status_with_metadata()69 async fn status_with_metadata() {
70     const MESSAGE: &str = "Internal error, see metadata for details";
71 
72     const ASCII_NAME: &str = "x-host-ip";
73     const ASCII_VALUE: &str = "127.0.0.1";
74 
75     const BINARY_NAME: &str = "x-host-name-bin";
76     const BINARY_VALUE: &[u8] = b"localhost";
77 
78     struct Svc;
79 
80     #[tonic::async_trait]
81     impl test_server::Test for Svc {
82         async fn unary_call(&self, _: Request<Input>) -> Result<Response<Output>, Status> {
83             let mut metadata = MetadataMap::new();
84             metadata.insert(ASCII_NAME, ASCII_VALUE.parse().unwrap());
85             metadata.insert_bin(BINARY_NAME, MetadataValue::from_bytes(BINARY_VALUE));
86 
87             Err(Status::with_metadata(Code::Internal, MESSAGE, metadata))
88         }
89     }
90 
91     let svc = test_server::TestServer::new(Svc);
92 
93     let (tx, rx) = oneshot::channel::<()>();
94 
95     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
96     let addr = listener.local_addr().unwrap();
97     let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));
98 
99     let jh = tokio::spawn(async move {
100         Server::builder()
101             .add_service(svc)
102             .serve_with_incoming_shutdown(incoming, async { drop(rx.await) })
103             .await
104             .unwrap();
105     });
106 
107     tokio::time::sleep(Duration::from_millis(100)).await;
108 
109     let mut channel = test_client::TestClient::connect(format!("http://{addr}"))
110         .await
111         .unwrap();
112 
113     let err = channel
114         .unary_call(Request::new(Input {}))
115         .await
116         .unwrap_err();
117 
118     assert_eq!(err.code(), Code::Internal);
119     assert_eq!(err.message(), MESSAGE);
120 
121     let metadata = err.metadata();
122 
123     assert_eq!(
124         metadata.get(ASCII_NAME).unwrap().to_str().unwrap(),
125         ASCII_VALUE
126     );
127 
128     assert_eq!(
129         metadata.get_bin(BINARY_NAME).unwrap().to_bytes().unwrap(),
130         BINARY_VALUE
131     );
132 
133     tx.send(()).unwrap();
134 
135     jh.await.unwrap();
136 }
137 
138 type Stream<T> = std::pin::Pin<
139     Box<dyn tokio_stream::Stream<Item = std::result::Result<T, Status>> + Send + 'static>,
140 >;
141 
142 #[tokio::test]
status_from_server_stream()143 async fn status_from_server_stream() {
144     integration_tests::trace_init();
145 
146     struct Svc;
147 
148     #[tonic::async_trait]
149     impl test_stream_server::TestStream for Svc {
150         type StreamCallStream = Stream<OutputStream>;
151 
152         async fn stream_call(
153             &self,
154             _: Request<InputStream>,
155         ) -> Result<Response<Self::StreamCallStream>, Status> {
156             let s = tokio_stream::iter(vec![
157                 Err::<OutputStream, _>(Status::unavailable("foo")),
158                 Err::<OutputStream, _>(Status::unavailable("bar")),
159             ]);
160             Ok(Response::new(Box::pin(s) as Self::StreamCallStream))
161         }
162     }
163 
164     let svc = test_stream_server::TestStreamServer::new(Svc);
165 
166     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
167     let addr = listener.local_addr().unwrap();
168     let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));
169 
170     tokio::spawn(async move {
171         Server::builder()
172             .add_service(svc)
173             .serve_with_incoming(incoming)
174             .await
175             .unwrap();
176     });
177 
178     tokio::time::sleep(Duration::from_millis(100)).await;
179 
180     let mut client = test_stream_client::TestStreamClient::connect(format!("http://{addr}"))
181         .await
182         .unwrap();
183 
184     let mut stream = client
185         .stream_call(InputStream {})
186         .await
187         .unwrap()
188         .into_inner();
189 
190     assert_eq!(stream.message().await.unwrap_err().message(), "foo");
191     assert_eq!(stream.message().await.unwrap(), None);
192 }
193 
194 #[tokio::test]
status_from_server_stream_with_source()195 async fn status_from_server_stream_with_source() {
196     integration_tests::trace_init();
197 
198     let channel = Endpoint::try_from("http://[::]:50051")
199         .unwrap()
200         .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move {
201             Err::<TokioIo<MockStream>, _>(std::io::Error::other("WTF"))
202         }));
203 
204     let mut client = test_stream_client::TestStreamClient::new(channel);
205 
206     let error = client.stream_call(InputStream {}).await.unwrap_err();
207 
208     let source = error.source().unwrap();
209     source.downcast_ref::<tonic::transport::Error>().unwrap();
210 }
211 
212 #[tokio::test]
message_and_then_status_from_server_stream()213 async fn message_and_then_status_from_server_stream() {
214     integration_tests::trace_init();
215 
216     struct Svc;
217 
218     #[tonic::async_trait]
219     impl test_stream_server::TestStream for Svc {
220         type StreamCallStream = Stream<OutputStream>;
221 
222         async fn stream_call(
223             &self,
224             _: Request<InputStream>,
225         ) -> Result<Response<Self::StreamCallStream>, Status> {
226             let s = tokio_stream::iter(vec![
227                 Ok(OutputStream {}),
228                 Err::<OutputStream, _>(Status::unavailable("foo")),
229             ]);
230             Ok(Response::new(Box::pin(s) as Self::StreamCallStream))
231         }
232     }
233 
234     let svc = test_stream_server::TestStreamServer::new(Svc);
235 
236     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
237     let addr = listener.local_addr().unwrap();
238     let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));
239 
240     tokio::spawn(async move {
241         Server::builder()
242             .add_service(svc)
243             .serve_with_incoming(incoming)
244             .await
245             .unwrap();
246     });
247 
248     tokio::time::sleep(Duration::from_millis(100)).await;
249 
250     let mut client = test_stream_client::TestStreamClient::connect(format!("http://{addr}"))
251         .await
252         .unwrap();
253 
254     let mut stream = client
255         .stream_call(InputStream {})
256         .await
257         .unwrap()
258         .into_inner();
259 
260     assert_eq!(stream.message().await.unwrap(), Some(OutputStream {}));
261     assert_eq!(stream.message().await.unwrap_err().message(), "foo");
262     assert_eq!(stream.message().await.unwrap(), None);
263 }
264