1 use core::fmt;
2 use std::future::Future;
3 use std::pin::Pin;
4 use std::task::{ready, Context, Poll};
5
6 use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
7 use pin_project::pin_project;
8 use tonic::metadata::GRPC_CONTENT_TYPE;
9 use tonic::{body::Body, server::NamedService};
10 use tower_service::Service;
11 use tracing::{debug, trace};
12
13 use crate::call::content_types::is_grpc_web;
14 use crate::call::{Encoding, GrpcWebCall};
15
/// Service implementing the grpc-web protocol.
///
/// Wraps an inner service `S`. Requests classified as grpc-web are rewritten
/// into gRPC requests before being forwarded to `inner`, and the responses are
/// converted back to grpc-web; see the `Service` implementation for the exact
/// routing rules.
#[derive(Debug, Clone)]
pub struct GrpcWebService<S> {
    /// The wrapped service that accepted requests are forwarded to.
    inner: S,
}
21
/// Classification of an incoming request, as computed by `RequestKind::new`
/// from the request's headers, method and HTTP version.
#[derive(Debug, PartialEq)]
enum RequestKind<'a> {
    // The request is considered a grpc-web request if its `content-type`
    // header is exactly one of:
    //
    //  - "application/grpc-web"
    //  - "application/grpc-web+proto"
    //  - "application/grpc-web-text"
    //  - "application/grpc-web-text+proto"
    //
    // (The actual decision is delegated to `is_grpc_web`.)
    GrpcWeb {
        // HTTP method of the request, borrowed from the request itself.
        method: &'a Method,
        // Result of `Encoding::from_content_type` on the request headers.
        encoding: Encoding,
        // Result of `Encoding::from_accept` on the request headers.
        accept: Encoding,
    },
    // All other requests, including `application/grpc`. Carries the HTTP
    // version of the request.
    Other(http::Version),
}
39
40 impl<S> GrpcWebService<S> {
new(inner: S) -> Self41 pub(crate) fn new(inner: S) -> Self {
42 GrpcWebService { inner }
43 }
44 }
45
46 impl<S, B> Service<Request<B>> for GrpcWebService<S>
47 where
48 S: Service<Request<Body>, Response = Response<Body>>,
49 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
50 B::Error: Into<crate::BoxError> + fmt::Display,
51 {
52 type Response = S::Response;
53 type Error = S::Error;
54 type Future = ResponseFuture<S::Future>;
55
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>56 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57 self.inner.poll_ready(cx)
58 }
59
call(&mut self, req: Request<B>) -> Self::Future60 fn call(&mut self, req: Request<B>) -> Self::Future {
61 match RequestKind::new(req.headers(), req.method(), req.version()) {
62 // A valid grpc-web request, regardless of HTTP version.
63 //
64 // If the request includes an `origin` header, we verify it is allowed
65 // to access the resource, an HTTP 403 response is returned otherwise.
66 //
67 // If the origin is allowed to access the resource or there is no
68 // `origin` header present, translate the request into a grpc request,
69 // call the inner service, and translate the response back to
70 // grpc-web.
71 RequestKind::GrpcWeb {
72 method: &Method::POST,
73 encoding,
74 accept,
75 } => {
76 trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
77
78 ResponseFuture {
79 case: Case::GrpcWeb {
80 future: self.inner.call(coerce_request(req, encoding)),
81 accept,
82 },
83 }
84 }
85
86 // The request's content-type matches one of the 4 supported grpc-web
87 // content-types, but the request method is not `POST`.
88 // This is not a valid grpc-web request, return HTTP 405.
89 RequestKind::GrpcWeb { .. } => {
90 debug!(kind = "simple", error="method not allowed", method = ?req.method());
91
92 ResponseFuture {
93 case: Case::immediate(StatusCode::METHOD_NOT_ALLOWED),
94 }
95 }
96
97 // All http/2 requests that are not grpc-web are passed through to the inner service,
98 // whatever they are.
99 RequestKind::Other(Version::HTTP_2) => {
100 debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
101 ResponseFuture {
102 case: Case::Other {
103 future: self.inner.call(req.map(Body::new)),
104 },
105 }
106 }
107
108 // Return HTTP 400 for all other requests.
109 RequestKind::Other(_) => {
110 debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
111
112 ResponseFuture {
113 case: Case::immediate(StatusCode::BAD_REQUEST),
114 }
115 }
116 }
117 }
118 }
119
/// Response future for the [`GrpcWebService`].
///
/// Depending on how the request was classified, this either drives the inner
/// service's future (optionally converting its response back to grpc-web) or
/// yields a response that was prepared up front without calling the inner
/// service.
#[pin_project]
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<F> {
    /// Which of the request-handling paths this future is driving.
    #[pin]
    case: Case<F>,
}
127
/// Internal state of a [`ResponseFuture`], chosen in `GrpcWebService::call`.
#[pin_project(project = CaseProj)]
enum Case<F> {
    /// A grpc-web request forwarded to the inner service. Once `future`
    /// resolves, its response is passed through `coerce_response` together
    /// with `accept`.
    GrpcWeb {
        /// The inner service's response future.
        #[pin]
        future: F,
        /// Encoding handed to `coerce_response` for the outgoing response.
        accept: Encoding,
    },
    /// A non grpc-web request passed through to the inner service; the
    /// result of `future` is returned unchanged.
    Other {
        /// The inner service's response future.
        #[pin]
        future: F,
    },
    /// A response produced without calling the inner service.
    ImmediateResponse {
        /// Head of the response to return; taken (left as `None`) by the
        /// poll that yields it.
        res: Option<http::response::Parts>,
    },
}
143
144 impl<F> Case<F> {
immediate(status: StatusCode) -> Self145 fn immediate(status: StatusCode) -> Self {
146 let (res, ()) = Response::builder()
147 .status(status)
148 .body(())
149 .unwrap()
150 .into_parts();
151 Self::ImmediateResponse { res: Some(res) }
152 }
153 }
154
155 impl<F, E> Future for ResponseFuture<F>
156 where
157 F: Future<Output = Result<Response<Body>, E>>,
158 {
159 type Output = Result<Response<Body>, E>;
160
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>161 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162 let this = self.project();
163
164 match this.case.project() {
165 CaseProj::GrpcWeb { future, accept } => {
166 let res = ready!(future.poll(cx))?;
167
168 Poll::Ready(Ok(coerce_response(res, *accept)))
169 }
170 CaseProj::Other { future } => future.poll(cx),
171 CaseProj::ImmediateResponse { res } => {
172 let res = Response::from_parts(res.take().unwrap(), Body::empty());
173 Poll::Ready(Ok(res))
174 }
175 }
176 }
177 }
178
179 impl<S: NamedService> NamedService for GrpcWebService<S> {
180 const NAME: &'static str = S::NAME;
181 }
182
183 impl<F> fmt::Debug for ResponseFuture<F> {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("ResponseFuture").finish()
186 }
187 }
188
189 impl<'a> RequestKind<'a> {
new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self190 fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
191 if is_grpc_web(headers) {
192 return RequestKind::GrpcWeb {
193 method,
194 encoding: Encoding::from_content_type(headers),
195 accept: Encoding::from_accept(headers),
196 };
197 }
198
199 RequestKind::Other(version)
200 }
201 }
202
203 // Mutating request headers to conform to a gRPC request is not really
204 // necessary for us at this point. We could remove most of these except
205 // maybe for inserting `header::TE`, which tonic should check?
coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body> where B: http_body::Body<Data = bytes::Bytes> + Send + 'static, B::Error: Into<crate::BoxError> + fmt::Display,206 fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
207 where
208 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
209 B::Error: Into<crate::BoxError> + fmt::Display,
210 {
211 req.headers_mut().remove(header::CONTENT_LENGTH);
212
213 req.headers_mut()
214 .insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
215
216 req.headers_mut()
217 .insert(header::TE, HeaderValue::from_static("trailers"));
218
219 req.headers_mut().insert(
220 header::ACCEPT_ENCODING,
221 HeaderValue::from_static("identity,deflate,gzip"),
222 );
223
224 req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
225 }
226
coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body> where B: http_body::Body<Data = bytes::Bytes> + Send + 'static, B::Error: Into<crate::BoxError> + fmt::Display,227 fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
228 where
229 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
230 B::Error: Into<crate::BoxError> + fmt::Display,
231 {
232 let mut res = res
233 .map(|b| GrpcWebCall::response(b, encoding))
234 .map(Body::new);
235
236 res.headers_mut().insert(
237 header::CONTENT_TYPE,
238 HeaderValue::from_static(encoding.to_content_type()),
239 );
240
241 res
242 }
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::call::content_types::*;
    use http::header::{
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
    };
    use tower_layer::Layer as _;

    /// Boxed future type returned by the stub service.
    type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

    /// Stub inner service: always ready, ignores the request and answers
    /// with `Response::new(Body::default())`.
    #[derive(Debug, Clone)]
    struct Svc;

    impl tower_service::Service<Request<Body>> for Svc {
        type Response = Response<Body>;
        type Error = String;
        type Future = BoxFuture<Self::Response, Self::Error>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<Body>) -> Self::Future {
            Box::pin(async { Ok(Response::new(Body::default())) })
        }
    }

    impl NamedService for Svc {
        const NAME: &'static str = "test";
    }

    /// Wraps `service` in the grpc-web layer first and a default CORS layer
    /// on the outside.
    fn enable<S>(service: S) -> tower_http::cors::Cors<GrpcWebService<S>>
    where
        S: Service<http::Request<Body>, Response = http::Response<Body>>,
    {
        let grpc_web = crate::GrpcWebLayer::new().layer(service);
        tower_http::cors::CorsLayer::new().layer(grpc_web)
    }

    /// Requests carrying a grpc-web content-type.
    mod grpc_web {
        use super::*;
        use tower_layer::Layer;

        /// A `POST` grpc-web request with an `origin` header.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::POST)
                .header(CONTENT_TYPE, GRPC_WEB)
                .header(ORIGIN, "http://example.com")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn default_cors_config() {
            let mut service = enable(Svc);

            let response = service.call(request()).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn web_layer() {
            // Only the grpc-web layer, without CORS in front of it.
            let mut service = crate::GrpcWebLayer::new().layer(Svc);

            let response = service.call(request()).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn without_origin() {
            let mut service = enable(Svc);

            let mut req = request();
            req.headers_mut().remove(ORIGIN);

            let response = service.call(req).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn only_post_and_options_allowed() {
            let mut service = enable(Svc);

            let rejected = [
                Method::GET,
                Method::PUT,
                Method::DELETE,
                Method::HEAD,
                Method::PATCH,
            ];

            for method in rejected {
                let mut req = request();
                *req.method_mut() = method.clone();

                let response = service.call(req).await.unwrap();

                assert_eq!(
                    response.status(),
                    StatusCode::METHOD_NOT_ALLOWED,
                    "{} should not be allowed",
                    method
                );
            }
        }

        #[tokio::test]
        async fn grpc_web_content_types() {
            let mut service = enable(Svc);

            for content_type in [GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
                let mut req = request();
                req.headers_mut()
                    .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));

                let response = service.call(req).await.unwrap();

                assert_eq!(response.status(), StatusCode::OK);
            }
        }
    }

    /// CORS preflight (`OPTIONS`) requests.
    mod options {
        use super::*;

        /// A preflight request asking to `POST` with the `x-grpc-web` header.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::OPTIONS)
                .header(ORIGIN, "http://example.com")
                .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
                .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn valid_grpc_web_preflight() {
            let mut service = enable(Svc);

            let response = service.call(request()).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
        }
    }

    /// Plain gRPC requests (`application/grpc*` content-types).
    mod grpc {
        use super::*;

        /// An HTTP/2 request with the gRPC content-type.
        fn request() -> Request<Body> {
            Request::builder()
                .version(Version::HTTP_2)
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut service = enable(Svc);

            let response = service.call(request()).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK)
        }

        #[tokio::test]
        async fn h1_is_err() {
            let mut service = enable(Svc);

            // Same request as the helper builds, but over HTTP/1.1.
            let mut req = request();
            *req.version_mut() = Version::HTTP_11;

            let response = service.call(req).await.unwrap();

            assert_eq!(response.status(), StatusCode::BAD_REQUEST)
        }

        #[tokio::test]
        async fn content_type_variants() {
            let mut service = enable(Svc);

            for variant in ["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
                let content_type =
                    HeaderValue::from_maybe_shared(format!("application/{}", variant)).unwrap();

                let mut req = request();
                req.headers_mut().insert(CONTENT_TYPE, content_type);

                let response = service.call(req).await.unwrap();

                assert_eq!(response.status(), StatusCode::OK)
            }
        }
    }

    /// Requests that are neither grpc-web nor gRPC.
    mod other {
        use super::*;

        /// A default-version request with an unrelated content-type.
        fn request() -> Request<Body> {
            Request::builder()
                .header(CONTENT_TYPE, "application/text")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h1_is_err() {
            let mut service = enable(Svc);

            let response = service.call(request()).await.unwrap();

            assert_eq!(response.status(), StatusCode::BAD_REQUEST)
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut service = enable(Svc);

            let mut req = request();
            *req.version_mut() = Version::HTTP_2;

            let response = service.call(req).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK)
        }
    }
}
472