1 use http::header::CONTENT_TYPE; 2 use http::{Request, Response, Version}; 3 use pin_project::pin_project; 4 use std::fmt; 5 use std::future::Future; 6 use std::pin::Pin; 7 use std::task::{ready, Context, Poll}; 8 use tower_layer::Layer; 9 use tower_service::Service; 10 use tracing::debug; 11 12 use crate::call::content_types::GRPC_WEB; 13 use crate::call::GrpcWebCall; 14 15 /// Layer implementing the grpc-web protocol for clients. 16 #[derive(Debug, Default, Clone)] 17 pub struct GrpcWebClientLayer { 18 _priv: (), 19 } 20 21 impl GrpcWebClientLayer { 22 /// Create a new grpc-web for clients layer. new() -> GrpcWebClientLayer23 pub fn new() -> GrpcWebClientLayer { 24 Self::default() 25 } 26 } 27 28 impl<S> Layer<S> for GrpcWebClientLayer { 29 type Service = GrpcWebClientService<S>; 30 layer(&self, inner: S) -> Self::Service31 fn layer(&self, inner: S) -> Self::Service { 32 GrpcWebClientService::new(inner) 33 } 34 } 35 36 /// A [`Service`] that wraps some inner http service that will 37 /// coerce requests coming from [`tonic::client::Grpc`] into proper 38 /// `grpc-web` requests. 39 #[derive(Debug, Clone)] 40 pub struct GrpcWebClientService<S> { 41 inner: S, 42 } 43 44 impl<S> GrpcWebClientService<S> { 45 /// Create a new grpc-web for clients service. new(inner: S) -> Self46 pub fn new(inner: S) -> Self { 47 Self { inner } 48 } 49 } 50 51 impl<S, B1, B2> Service<Request<B1>> for GrpcWebClientService<S> 52 where 53 S: Service<Request<GrpcWebCall<B1>>, Response = Response<B2>>, 54 { 55 type Response = Response<GrpcWebCall<B2>>; 56 type Error = S::Error; 57 type Future = ResponseFuture<S::Future>; 58 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>59 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 60 self.inner.poll_ready(cx) 61 } 62 call(&mut self, mut req: Request<B1>) -> Self::Future63 fn call(&mut self, mut req: Request<B1>) -> Self::Future { 64 if req.version() == Version::HTTP_2 { 65 debug!("coercing HTTP2 request to HTTP1.1"); 66 67 *req.version_mut() = Version::HTTP_11; 68 } 69 70 req.headers_mut() 71 .insert(CONTENT_TYPE, GRPC_WEB.try_into().unwrap()); 72 73 let req = req.map(GrpcWebCall::client_request); 74 75 let fut = self.inner.call(req); 76 77 ResponseFuture { inner: fut } 78 } 79 } 80 81 /// Response future for the [`GrpcWebService`](crate::GrpcWebService). 82 #[pin_project] 83 #[must_use = "futures do nothing unless polled"] 84 pub struct ResponseFuture<F> { 85 #[pin] 86 inner: F, 87 } 88 89 impl<F, B, E> Future for ResponseFuture<F> 90 where 91 F: Future<Output = Result<Response<B>, E>>, 92 { 93 type Output = Result<Response<GrpcWebCall<B>>, E>; 94 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>95 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 96 let res = ready!(self.project().inner.poll(cx)); 97 98 Poll::Ready(res.map(|r| r.map(GrpcWebCall::client_response))) 99 } 100 } 101 102 impl<F> fmt::Debug for ResponseFuture<F> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 104 f.debug_struct("ResponseFuture").finish() 105 } 106 } 107