use std::{
    net::SocketAddr,
    pin::Pin,
    task::{Context, Poll},
};
use tonic::{transport::Server, Request, Response, Status};
use tower::{Layer, Service};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};
1080e90e3bSKaleb Elwert
1180e90e3bSKaleb Elwert pub mod hello_world {
1280e90e3bSKaleb Elwert tonic::include_proto!("helloworld");
1380e90e3bSKaleb Elwert }
1480e90e3bSKaleb Elwert
/// Stateless implementation of the `Greeter` gRPC service.
#[derive(Default)]
pub struct MyGreeter;
1780e90e3bSKaleb Elwert
1880e90e3bSKaleb Elwert #[tonic::async_trait]
1980e90e3bSKaleb Elwert impl Greeter for MyGreeter {
say_hello( &self, request: Request<HelloRequest>, ) -> Result<Response<HelloReply>, Status>2080e90e3bSKaleb Elwert async fn say_hello(
2180e90e3bSKaleb Elwert &self,
2280e90e3bSKaleb Elwert request: Request<HelloRequest>,
2380e90e3bSKaleb Elwert ) -> Result<Response<HelloReply>, Status> {
2480e90e3bSKaleb Elwert println!("Got a request from {:?}", request.remote_addr());
2580e90e3bSKaleb Elwert
2680e90e3bSKaleb Elwert let reply = hello_world::HelloReply {
2780e90e3bSKaleb Elwert message: format!("Hello {}!", request.into_inner().name),
2880e90e3bSKaleb Elwert };
2980e90e3bSKaleb Elwert Ok(Response::new(reply))
3080e90e3bSKaleb Elwert }
3180e90e3bSKaleb Elwert }
3280e90e3bSKaleb Elwert
3380e90e3bSKaleb Elwert #[tokio::main]
main() -> Result<(), Box<dyn std::error::Error>>3480e90e3bSKaleb Elwert async fn main() -> Result<(), Box<dyn std::error::Error>> {
3580e90e3bSKaleb Elwert let addr = "[::1]:50051".parse().unwrap();
3680e90e3bSKaleb Elwert let greeter = MyGreeter::default();
3780e90e3bSKaleb Elwert
3880e90e3bSKaleb Elwert println!("GreeterServer listening on {}", addr);
3980e90e3bSKaleb Elwert
404d2667d1SDavid Pedersen let svc = GreeterServer::new(greeter);
4180e90e3bSKaleb Elwert
424d2667d1SDavid Pedersen // The stack of middleware that our service will be wrapped in
434d2667d1SDavid Pedersen let layer = tower::ServiceBuilder::new()
444d2667d1SDavid Pedersen // Apply our own middleware
454d2667d1SDavid Pedersen .layer(MyMiddlewareLayer::default())
464d2667d1SDavid Pedersen // Interceptors can be also be applied as middleware
477b93470bStottoto .layer(tonic::service::InterceptorLayer::new(intercept))
484d2667d1SDavid Pedersen .into_inner();
494d2667d1SDavid Pedersen
504d2667d1SDavid Pedersen Server::builder()
514d2667d1SDavid Pedersen // Wrap all services in the middleware stack
524d2667d1SDavid Pedersen .layer(layer)
534d2667d1SDavid Pedersen .add_service(svc)
544d2667d1SDavid Pedersen .serve(addr)
554d2667d1SDavid Pedersen .await?;
5680e90e3bSKaleb Elwert
5780e90e3bSKaleb Elwert Ok(())
5880e90e3bSKaleb Elwert }
5980e90e3bSKaleb Elwert
604d2667d1SDavid Pedersen // An interceptor function.
intercept(req: Request<()>) -> Result<Request<()>, Status>614d2667d1SDavid Pedersen fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
624d2667d1SDavid Pedersen Ok(req)
634d2667d1SDavid Pedersen }
644d2667d1SDavid Pedersen
/// `tower` layer that wraps a service in [`MyMiddleware`]; carries no state.
#[derive(Debug, Clone, Default)]
struct MyMiddlewareLayer;
674d2667d1SDavid Pedersen
684d2667d1SDavid Pedersen impl<S> Layer<S> for MyMiddlewareLayer {
694d2667d1SDavid Pedersen type Service = MyMiddleware<S>;
704d2667d1SDavid Pedersen
layer(&self, service: S) -> Self::Service714d2667d1SDavid Pedersen fn layer(&self, service: S) -> Self::Service {
724d2667d1SDavid Pedersen MyMiddleware { inner: service }
734d2667d1SDavid Pedersen }
744d2667d1SDavid Pedersen }
754d2667d1SDavid Pedersen
/// Middleware service that delegates every request to a wrapped service.
#[derive(Debug, Clone)]
struct MyMiddleware<S> {
    /// The wrapped service that ultimately handles each request.
    inner: S,
}
8080e90e3bSKaleb Elwert
/// Heap-allocated, pinned, type-erased future that is `Send` and valid for `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a + Send>>;
82c9c4acbcStottoto
831359d7adStottoto impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for MyMiddleware<S>
8480e90e3bSKaleb Elwert where
851359d7adStottoto S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone + Send + 'static,
8680e90e3bSKaleb Elwert S::Future: Send + 'static,
871359d7adStottoto ReqBody: Send + 'static,
8880e90e3bSKaleb Elwert {
8980e90e3bSKaleb Elwert type Response = S::Response;
9080e90e3bSKaleb Elwert type Error = S::Error;
91c9c4acbcStottoto type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
9280e90e3bSKaleb Elwert
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>9380e90e3bSKaleb Elwert fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
9480e90e3bSKaleb Elwert self.inner.poll_ready(cx)
9580e90e3bSKaleb Elwert }
9680e90e3bSKaleb Elwert
call(&mut self, req: http::Request<ReqBody>) -> Self::Future971359d7adStottoto fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
98*d40f610dSShikhar Bhushan // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
994a917a32SDavid Pedersen let clone = self.inner.clone();
1004a917a32SDavid Pedersen let mut inner = std::mem::replace(&mut self.inner, clone);
10180e90e3bSKaleb Elwert
10280e90e3bSKaleb Elwert Box::pin(async move {
1034a917a32SDavid Pedersen // Do extra async work here...
1044a917a32SDavid Pedersen let response = inner.call(req).await?;
10580e90e3bSKaleb Elwert
1064a917a32SDavid Pedersen Ok(response)
10780e90e3bSKaleb Elwert })
10880e90e3bSKaleb Elwert }
10980e90e3bSKaleb Elwert }
110