// xref: /tonic/tonic-web/src/client.rs (revision 5e9a5bcd)
use http::header::CONTENT_TYPE;
use http::{HeaderValue, Request, Response, Version};
use pin_project::pin_project;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
use tracing::debug;

use crate::call::content_types::GRPC_WEB;
use crate::call::GrpcWebCall;

/// Tower [`Layer`] that adds grpc-web support to a client stack.
///
/// Applying it to an inner service produces a [`GrpcWebClientService`].
#[derive(Clone, Debug, Default)]
pub struct GrpcWebClientLayer {
    // Private unit field: the struct cannot be built with a literal outside
    // this module, so callers go through `new`/`Default` instead.
    _priv: (),
}

21 impl GrpcWebClientLayer {
22     /// Create a new grpc-web for clients layer.
new() -> GrpcWebClientLayer23     pub fn new() -> GrpcWebClientLayer {
24         Self::default()
25     }
26 }
27 
28 impl<S> Layer<S> for GrpcWebClientLayer {
29     type Service = GrpcWebClientService<S>;
30 
layer(&self, inner: S) -> Self::Service31     fn layer(&self, inner: S) -> Self::Service {
32         GrpcWebClientService::new(inner)
33     }
34 }
35 
/// Client-side [`Service`] adapter for grpc-web.
///
/// It wraps an inner HTTP service and rewrites the requests produced by
/// [`tonic::client::Grpc`] into `grpc-web` requests before forwarding them.
#[derive(Clone, Debug)]
pub struct GrpcWebClientService<S> {
    /// The wrapped HTTP service that receives the rewritten requests.
    inner: S,
}

44 impl<S> GrpcWebClientService<S> {
45     /// Create a new grpc-web for clients service.
new(inner: S) -> Self46     pub fn new(inner: S) -> Self {
47         Self { inner }
48     }
49 }
50 
51 impl<S, B1, B2> Service<Request<B1>> for GrpcWebClientService<S>
52 where
53     S: Service<Request<GrpcWebCall<B1>>, Response = Response<B2>>,
54 {
55     type Response = Response<GrpcWebCall<B2>>;
56     type Error = S::Error;
57     type Future = ResponseFuture<S::Future>;
58 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>59     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
60         self.inner.poll_ready(cx)
61     }
62 
call(&mut self, mut req: Request<B1>) -> Self::Future63     fn call(&mut self, mut req: Request<B1>) -> Self::Future {
64         if req.version() == Version::HTTP_2 {
65             debug!("coercing HTTP2 request to HTTP1.1");
66 
67             *req.version_mut() = Version::HTTP_11;
68         }
69 
70         req.headers_mut()
71             .insert(CONTENT_TYPE, GRPC_WEB.try_into().unwrap());
72 
73         let req = req.map(GrpcWebCall::client_request);
74 
75         let fut = self.inner.call(req);
76 
77         ResponseFuture { inner: fut }
78     }
79 }
80 
81 /// Response future for the [`GrpcWebService`](crate::GrpcWebService).
82 #[pin_project]
83 #[must_use = "futures do nothing unless polled"]
84 pub struct ResponseFuture<F> {
85     #[pin]
86     inner: F,
87 }
88 
89 impl<F, B, E> Future for ResponseFuture<F>
90 where
91     F: Future<Output = Result<Response<B>, E>>,
92 {
93     type Output = Result<Response<GrpcWebCall<B>>, E>;
94 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>95     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96         let res = ready!(self.project().inner.poll(cx));
97 
98         Poll::Ready(res.map(|r| r.map(GrpcWebCall::client_response)))
99     }
100 }
101 
102 impl<F> fmt::Debug for ResponseFuture<F> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result103     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104         f.debug_struct("ResponseFuture").finish()
105     }
106 }
107