mod conn;
mod incoming;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
pub use conn::Connected;
#[cfg(feature = "tls")]
pub use tls::ServerTlsConfig;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;
use incoming::TcpIncoming;
#[cfg(feature = "tls")]
pub(crate) use incoming::TlsStream;
#[cfg(feature = "tls")]
use crate::transport::Error;
use super::service::{Or, Routes, ServerIo, ServiceBuilderExt};
use crate::{body::BoxBody, request::ConnectionInfo};
use futures_core::Stream;
use futures_util::{
future::{self, MapErr},
TryFutureExt,
};
use http::{HeaderMap, Request, Response};
use hyper::{server::accept, Body};
use std::{
fmt,
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, util::Either, Service,
ServiceBuilder,
};
use tracing_futures::{Instrument, Instrumented};
type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
type TraceInterceptor = Arc<dyn Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static>;
#[derive(Default, Clone)]
pub struct Server {
trace_interceptor: Option<TraceInterceptor>,
concurrency_limit: Option<usize>,
timeout: Option<Duration>,
#[cfg(feature = "tls")]
tls: Option<TlsAcceptor>,
init_stream_window_size: Option<u32>,
init_connection_window_size: Option<u32>,
max_concurrent_streams: Option<u32>,
tcp_keepalive: Option<Duration>,
tcp_nodelay: bool,
}
#[derive(Debug)]
pub struct Router<A, B> {
server: Server,
routes: Routes<A, B, Request<Body>>,
}
pub trait NamedService {
const NAME: &'static str;
}
impl<S: NamedService, T> NamedService for Either<S, T> {
const NAME: &'static str = S::NAME;
}
impl Server {
pub fn builder() -> Self {
Server {
tcp_nodelay: true,
..Default::default()
}
}
}
impl Server {
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub fn tls_config(self, tls_config: ServerTlsConfig) -> Result<Self, Error> {
Ok(Server {
tls: Some(
tls_config
.tls_acceptor()
.map_err(|e| Error::from_source(e))?,
),
..self
})
}
pub fn concurrency_limit_per_connection(self, limit: usize) -> Self {
Server {
concurrency_limit: Some(limit),
..self
}
}
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
self.timeout = Some(timeout);
self
}
pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
Server {
init_stream_window_size: sz.into(),
..self
}
}
pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
Server {
init_connection_window_size: sz.into(),
..self
}
}
pub fn max_concurrent_streams(self, max: impl Into<Option<u32>>) -> Self {
Server {
max_concurrent_streams: max.into(),
..self
}
}
pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
Server {
tcp_keepalive,
..self
}
}
pub fn tcp_nodelay(self, enabled: bool) -> Self {
Server {
tcp_nodelay: enabled,
..self
}
}
pub fn trace_fn<F>(self, f: F) -> Self
where
F: Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static,
{
Server {
trace_interceptor: Some(Arc::new(f)),
..self
}
}
pub fn add_service<S>(&mut self, svc: S) -> Router<S, Unimplemented>
where
S: Service<Request<Body>, Response = Response<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
Router::new(self.clone(), svc)
}
pub fn add_optional_service<S>(
&mut self,
svc: Option<S>,
) -> Router<Either<S, Unimplemented>, Unimplemented>
where
S: Service<Request<Body>, Response = Response<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
let svc = match svc {
Some(some) => Either::A(some),
None => Either::B(Unimplemented::default()),
};
Router::new(self.clone(), svc)
}
pub(crate) async fn serve_with_shutdown<S, I, F, IO, IE>(
self,
svc: S,
incoming: I,
signal: Option<F>,
) -> Result<(), super::Error>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
F: Future<Output = ()>,
{
let span = self.trace_interceptor.clone();
let concurrency_limit = self.concurrency_limit;
let init_connection_window_size = self.init_connection_window_size;
let init_stream_window_size = self.init_stream_window_size;
let max_concurrent_streams = self.max_concurrent_streams;
let timeout = self.timeout.clone();
let tcp = incoming::tcp_incoming(incoming, self);
let incoming = accept::from_stream::<_, _, crate::Error>(tcp);
let svc = MakeSvc {
inner: svc,
concurrency_limit,
timeout,
span,
};
let server = hyper::Server::builder(incoming)
.http2_only(true)
.http2_initial_connection_window_size(init_connection_window_size)
.http2_initial_stream_window_size(init_stream_window_size)
.http2_max_concurrent_streams(max_concurrent_streams);
if let Some(signal) = signal {
server
.serve(svc)
.with_graceful_shutdown(signal)
.await
.map_err(super::Error::from_source)?
} else {
server.serve(svc).await.map_err(super::Error::from_source)?;
}
Ok(())
}
}
impl<S> Router<S, Unimplemented> {
pub(crate) fn new(server: Server, svc: S) -> Self
where
S: Service<Request<Body>, Response = Response<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
let svc_name = <S as NamedService>::NAME;
let svc_route = format!("/{}", svc_name);
let pred = move |req: &Request<Body>| {
let path = req.uri().path();
path.starts_with(&svc_route)
};
Self {
server,
routes: Routes::new(pred, svc, Unimplemented::default()),
}
}
}
impl<A, B> Router<A, B>
where
A: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
A::Future: Send + 'static,
A::Error: Into<crate::Error> + Send,
B: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
B::Future: Send + 'static,
B::Error: Into<crate::Error> + Send,
{
pub fn add_service<S>(self, svc: S) -> Router<S, Or<A, B, Request<Body>>>
where
S: Service<Request<Body>, Response = Response<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
let Self { routes, server } = self;
let svc_name = <S as NamedService>::NAME;
let svc_route = format!("/{}", svc_name);
let pred = move |req: &Request<Body>| {
let path = req.uri().path();
path.starts_with(&svc_route)
};
let routes = routes.push(pred, svc);
Router { server, routes }
}
pub fn add_optional_service<S>(
self,
svc: Option<S>,
) -> Router<Either<S, Unimplemented>, Or<A, B, Request<Body>>>
where
S: Service<Request<Body>, Response = Response<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
let Self { routes, server } = self;
let svc_name = <S as NamedService>::NAME;
let svc_route = format!("/{}", svc_name);
let pred = move |req: &Request<Body>| {
let path = req.uri().path();
path.starts_with(&svc_route)
};
let svc = match svc {
Some(some) => Either::A(some),
None => Either::B(Unimplemented::default()),
};
let routes = routes.push(pred, svc);
Router { server, routes }
}
pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> {
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None)
.await
}
pub async fn serve_with_shutdown<F: Future<Output = ()>>(
self,
addr: SocketAddr,
signal: F,
) -> Result<(), super::Error> {
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
self.server
.serve_with_shutdown(self.routes, incoming, Some(signal))
.await
}
pub async fn serve_with_incoming<I, IO, IE>(self, incoming: I) -> Result<(), super::Error>
where
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None)
.await
}
pub async fn serve_with_incoming_shutdown<I, IO, IE, F>(
self,
incoming: I,
signal: F,
) -> Result<(), super::Error>
where
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
F: Future<Output = ()>,
{
self.server
.serve_with_shutdown(self.routes, incoming, Some(signal))
.await
}
}
impl fmt::Debug for Server {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Builder").finish()
}
}
struct Svc<S> {
inner: S,
span: Option<TraceInterceptor>,
conn_info: ConnectionInfo,
}
impl<S> Service<Request<Body>> for Svc<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>>,
S::Error: Into<crate::Error>,
{
type Response = Response<BoxBody>;
type Error = crate::Error;
type Future = MapErr<Instrumented<S::Future>, fn(S::Error) -> crate::Error>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let span = if let Some(trace_interceptor) = &self.span {
trace_interceptor(req.headers())
} else {
tracing::Span::none()
};
req.extensions_mut().insert(self.conn_info.clone());
self.inner.call(req).instrument(span).map_err(|e| e.into())
}
}
impl<S> fmt::Debug for Svc<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Svc").finish()
}
}
struct MakeSvc<S> {
concurrency_limit: Option<usize>,
timeout: Option<Duration>,
inner: S,
span: Option<TraceInterceptor>,
}
impl<S> Service<&ServerIo> for MakeSvc<S>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
{
type Response = BoxService;
type Error = crate::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn call(&mut self, io: &ServerIo) -> Self::Future {
let conn_info = crate::request::ConnectionInfo {
remote_addr: io.remote_addr(),
peer_certs: io.peer_certs().map(Arc::new),
};
let svc = self.inner.clone();
let concurrency_limit = self.concurrency_limit;
let timeout = self.timeout.clone();
let span = self.span.clone();
Box::pin(async move {
let svc = ServiceBuilder::new()
.optional_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(timeout.map(TimeoutLayer::new))
.service(svc);
let svc = BoxService::new(Svc {
inner: svc,
span,
conn_info,
});
Ok(svc)
})
}
}
#[derive(Default, Clone, Debug)]
#[doc(hidden)]
pub struct Unimplemented {
_p: (),
}
impl Service<Request<Body>> for Unimplemented {
type Response = Response<BoxBody>;
type Error = crate::Error;
type Future = future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn call(&mut self, _req: Request<Body>) -> Self::Future {
future::ok(
http::Response::builder()
.status(200)
.header("grpc-status", "12")
.body(BoxBody::empty())
.unwrap(),
)
}
}