xref: /tonic/tonic-web/src/service.rs (revision 72b0fd59)
1 use core::fmt;
2 use std::future::Future;
3 use std::pin::Pin;
4 use std::task::{ready, Context, Poll};
5 
6 use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
7 use pin_project::pin_project;
8 use tonic::metadata::GRPC_CONTENT_TYPE;
9 use tonic::{body::Body, server::NamedService};
10 use tower_service::Service;
11 use tracing::{debug, trace};
12 
13 use crate::call::content_types::is_grpc_web;
14 use crate::call::{Encoding, GrpcWebCall};
15 
/// A [`Service`] wrapper that speaks the grpc-web protocol in front of an
/// inner gRPC service.
///
/// Requests recognised as grpc-web are rewritten into gRPC requests before
/// they reach `inner`, and the responses are translated back to grpc-web.
#[derive(Clone, Debug)]
pub struct GrpcWebService<S> {
    /// The wrapped service that receives the translated requests.
    inner: S,
}
21 
22 #[derive(Debug, PartialEq)]
23 enum RequestKind<'a> {
24     // The request is considered a grpc-web request if its `content-type`
25     // header is exactly one of:
26     //
27     //  - "application/grpc-web"
28     //  - "application/grpc-web+proto"
29     //  - "application/grpc-web-text"
30     //  - "application/grpc-web-text+proto"
31     GrpcWeb {
32         method: &'a Method,
33         encoding: Encoding,
34         accept: Encoding,
35     },
36     // All other requests, including `application/grpc`
37     Other(http::Version),
38 }
39 
40 impl<S> GrpcWebService<S> {
new(inner: S) -> Self41     pub(crate) fn new(inner: S) -> Self {
42         GrpcWebService { inner }
43     }
44 }
45 
46 impl<S, B> Service<Request<B>> for GrpcWebService<S>
47 where
48     S: Service<Request<Body>, Response = Response<Body>>,
49     B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
50     B::Error: Into<crate::BoxError> + fmt::Display,
51 {
52     type Response = S::Response;
53     type Error = S::Error;
54     type Future = ResponseFuture<S::Future>;
55 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>56     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57         self.inner.poll_ready(cx)
58     }
59 
call(&mut self, req: Request<B>) -> Self::Future60     fn call(&mut self, req: Request<B>) -> Self::Future {
61         match RequestKind::new(req.headers(), req.method(), req.version()) {
62             // A valid grpc-web request, regardless of HTTP version.
63             //
64             // If the request includes an `origin` header, we verify it is allowed
65             // to access the resource, an HTTP 403 response is returned otherwise.
66             //
67             // If the origin is allowed to access the resource or there is no
68             // `origin` header present, translate the request into a grpc request,
69             // call the inner service, and translate the response back to
70             // grpc-web.
71             RequestKind::GrpcWeb {
72                 method: &Method::POST,
73                 encoding,
74                 accept,
75             } => {
76                 trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
77 
78                 ResponseFuture {
79                     case: Case::GrpcWeb {
80                         future: self.inner.call(coerce_request(req, encoding)),
81                         accept,
82                     },
83                 }
84             }
85 
86             // The request's content-type matches one of the 4 supported grpc-web
87             // content-types, but the request method is not `POST`.
88             // This is not a valid grpc-web request, return HTTP 405.
89             RequestKind::GrpcWeb { .. } => {
90                 debug!(kind = "simple", error="method not allowed", method = ?req.method());
91 
92                 ResponseFuture {
93                     case: Case::immediate(StatusCode::METHOD_NOT_ALLOWED),
94                 }
95             }
96 
97             // All http/2 requests that are not grpc-web are passed through to the inner service,
98             // whatever they are.
99             RequestKind::Other(Version::HTTP_2) => {
100                 debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
101                 ResponseFuture {
102                     case: Case::Other {
103                         future: self.inner.call(req.map(Body::new)),
104                     },
105                 }
106             }
107 
108             // Return HTTP 400 for all other requests.
109             RequestKind::Other(_) => {
110                 debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
111 
112                 ResponseFuture {
113                     case: Case::immediate(StatusCode::BAD_REQUEST),
114                 }
115             }
116         }
117     }
118 }
119 
120 /// Response future for the [`GrpcWebService`].
121 #[pin_project]
122 #[must_use = "futures do nothing unless polled"]
123 pub struct ResponseFuture<F> {
124     #[pin]
125     case: Case<F>,
126 }
127 
128 #[pin_project(project = CaseProj)]
129 enum Case<F> {
130     GrpcWeb {
131         #[pin]
132         future: F,
133         accept: Encoding,
134     },
135     Other {
136         #[pin]
137         future: F,
138     },
139     ImmediateResponse {
140         res: Option<http::response::Parts>,
141     },
142 }
143 
144 impl<F> Case<F> {
immediate(status: StatusCode) -> Self145     fn immediate(status: StatusCode) -> Self {
146         let (res, ()) = Response::builder()
147             .status(status)
148             .body(())
149             .unwrap()
150             .into_parts();
151         Self::ImmediateResponse { res: Some(res) }
152     }
153 }
154 
155 impl<F, E> Future for ResponseFuture<F>
156 where
157     F: Future<Output = Result<Response<Body>, E>>,
158 {
159     type Output = Result<Response<Body>, E>;
160 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>161     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162         let this = self.project();
163 
164         match this.case.project() {
165             CaseProj::GrpcWeb { future, accept } => {
166                 let res = ready!(future.poll(cx))?;
167 
168                 Poll::Ready(Ok(coerce_response(res, *accept)))
169             }
170             CaseProj::Other { future } => future.poll(cx),
171             CaseProj::ImmediateResponse { res } => {
172                 let res = Response::from_parts(res.take().unwrap(), Body::empty());
173                 Poll::Ready(Ok(res))
174             }
175         }
176     }
177 }
178 
179 impl<S: NamedService> NamedService for GrpcWebService<S> {
180     const NAME: &'static str = S::NAME;
181 }
182 
183 impl<F> fmt::Debug for ResponseFuture<F> {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result184     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185         f.debug_struct("ResponseFuture").finish()
186     }
187 }
188 
189 impl<'a> RequestKind<'a> {
new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self190     fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
191         if is_grpc_web(headers) {
192             return RequestKind::GrpcWeb {
193                 method,
194                 encoding: Encoding::from_content_type(headers),
195                 accept: Encoding::from_accept(headers),
196             };
197         }
198 
199         RequestKind::Other(version)
200     }
201 }
202 
203 // Mutating request headers to conform to a gRPC request is not really
204 // necessary for us at this point. We could remove most of these except
205 // maybe for inserting `header::TE`, which tonic should check?
coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body> where B: http_body::Body<Data = bytes::Bytes> + Send + 'static, B::Error: Into<crate::BoxError> + fmt::Display,206 fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
207 where
208     B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
209     B::Error: Into<crate::BoxError> + fmt::Display,
210 {
211     req.headers_mut().remove(header::CONTENT_LENGTH);
212 
213     req.headers_mut()
214         .insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
215 
216     req.headers_mut()
217         .insert(header::TE, HeaderValue::from_static("trailers"));
218 
219     req.headers_mut().insert(
220         header::ACCEPT_ENCODING,
221         HeaderValue::from_static("identity,deflate,gzip"),
222     );
223 
224     req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
225 }
226 
coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body> where B: http_body::Body<Data = bytes::Bytes> + Send + 'static, B::Error: Into<crate::BoxError> + fmt::Display,227 fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
228 where
229     B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
230     B::Error: Into<crate::BoxError> + fmt::Display,
231 {
232     let mut res = res
233         .map(|b| GrpcWebCall::response(b, encoding))
234         .map(Body::new);
235 
236     res.headers_mut().insert(
237         header::CONTENT_TYPE,
238         HeaderValue::from_static(encoding.to_content_type()),
239     );
240 
241     res
242 }
243 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::call::content_types::*;
    use http::header::{
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
    };
    use tower_layer::Layer as _;

    type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

    /// Stub inner service: always ready, answers every request with a
    /// default (200, empty body) response.
    #[derive(Debug, Clone)]
    struct Svc;

    impl tower_service::Service<Request<Body>> for Svc {
        type Response = Response<Body>;
        type Error = String;
        type Future = BoxFuture<Self::Response, Self::Error>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            Box::pin(async { Ok(Response::new(Body::default())) })
        }
    }

    impl NamedService for Svc {
        const NAME: &'static str = "test";
    }

    /// Wraps `service` in the grpc-web layer, and that in turn in a default
    /// CORS layer (CORS outermost).
    fn enable<S>(service: S) -> tower_http::cors::Cors<GrpcWebService<S>>
    where
        S: Service<http::Request<Body>, Response = http::Response<Body>>,
    {
        let grpc_web = crate::GrpcWebLayer::new().layer(service);

        tower_http::cors::CorsLayer::new().layer(grpc_web)
    }

    /// Requests carrying a grpc-web content-type.
    mod grpc_web {
        use super::*;
        use tower_layer::Layer;

        /// A `POST` grpc-web request with an `origin` header.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::POST)
                .header(CONTENT_TYPE, GRPC_WEB)
                .header(ORIGIN, "http://example.com")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn default_cors_config() {
            let mut svc = enable(Svc);

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn web_layer() {
            // The grpc-web layer on its own, without the CORS layer.
            let mut svc = crate::GrpcWebLayer::new().layer(Svc);

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn without_origin() {
            let mut svc = enable(Svc);
            let mut req = request();
            req.headers_mut().remove(ORIGIN);

            let res = svc.call(req).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn only_post_and_options_allowed() {
            let mut svc = enable(Svc);

            let rejected = [
                Method::GET,
                Method::PUT,
                Method::DELETE,
                Method::HEAD,
                Method::PATCH,
            ];

            for method in rejected {
                let mut req = request();
                *req.method_mut() = method.clone();

                let res = svc.call(req).await.unwrap();

                assert_eq!(
                    res.status(),
                    StatusCode::METHOD_NOT_ALLOWED,
                    "{method} should not be allowed"
                );
            }
        }

        #[tokio::test]
        async fn grpc_web_content_types() {
            let mut svc = enable(Svc);

            for ct in [GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
                let mut req = request();
                req.headers_mut()
                    .insert(CONTENT_TYPE, HeaderValue::from_static(ct));

                let res = svc.call(req).await.unwrap();

                assert_eq!(res.status(), StatusCode::OK);
            }
        }
    }

    /// `OPTIONS` preflight requests.
    mod options {
        use super::*;

        /// A preflight request asking to `POST` with the `x-grpc-web` header.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::OPTIONS)
                .header(ORIGIN, "http://example.com")
                .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
                .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn valid_grpc_web_preflight() {
            let mut svc = enable(Svc);

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }
    }

    /// Plain gRPC requests (not grpc-web).
    mod grpc {
        use super::*;

        /// An HTTP/2 request with the gRPC content-type.
        fn request() -> Request<Body> {
            Request::builder()
                .version(Version::HTTP_2)
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = enable(Svc);

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = enable(Svc);

            // No explicit version: the builder's default is used instead of
            // HTTP/2.
            let req = Request::builder()
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap();

            let res = svc.call(req).await.unwrap();

            assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        }

        #[tokio::test]
        async fn content_type_variants() {
            let mut svc = enable(Svc);

            for variant in ["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
                let content_type =
                    HeaderValue::from_maybe_shared(format!("application/{variant}")).unwrap();

                let mut req = request();
                req.headers_mut().insert(CONTENT_TYPE, content_type);

                let res = svc.call(req).await.unwrap();

                assert_eq!(res.status(), StatusCode::OK);
            }
        }
    }

    /// Requests that are neither gRPC nor grpc-web.
    mod other {
        use super::*;

        /// A request with an unrelated content-type and the builder's
        /// default HTTP version.
        fn request() -> Request<Body> {
            Request::builder()
                .header(CONTENT_TYPE, "application/text")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = enable(Svc);

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = enable(Svc);
            let mut req = request();
            *req.version_mut() = Version::HTTP_2;

            let res = svc.call(req).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }
    }
}
472