use crate::{
body::{Body, BoxBody},
client::GrpcService,
codec::{encode_client, Codec, Streaming},
interceptor::Interceptor,
Code, Request, Response, Status,
};
use futures_core::Stream;
use futures_util::{future, stream, TryStreamExt};
use http::{
header::{HeaderValue, CONTENT_TYPE, TE},
uri::{Parts, PathAndQuery, Uri},
};
use http_body::Body as HttpBody;
use std::fmt;
/// A generic gRPC client that sends requests through an inner service `T`.
///
/// When constructed with [`Grpc::with_interceptor`], every outgoing request is
/// first passed through the configured [`Interceptor`].
pub struct Grpc<T> {
    // The underlying service that requests are dispatched to.
    inner: T,
    // Interceptor applied to each outgoing request; `None` when none was configured.
    interceptor: Option<Interceptor>,
}
impl<T> Grpc<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
interceptor: None,
}
}
pub fn with_interceptor(inner: T, interceptor: impl Into<Interceptor>) -> Self {
Self {
inner,
interceptor: Some(interceptor.into()),
}
}
pub async fn ready(&mut self) -> Result<(), T::Error>
where
T: GrpcService<BoxBody>,
{
future::poll_fn(|cx| self.inner.poll_ready(cx)).await
}
pub async fn unary<M1, M2, C>(
&mut self,
request: Request<M1>,
path: PathAndQuery,
codec: C,
) -> Result<Response<M2>, Status>
where
T: GrpcService<BoxBody>,
T::ResponseBody: Body + HttpBody + Send + 'static,
<T::ResponseBody as HttpBody>::Error: Into<crate::Error>,
C: Codec<Encode = M1, Decode = M2>,
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let request = request.map(|m| stream::once(future::ready(m)));
self.client_streaming(request, path, codec).await
}
pub async fn client_streaming<S, M1, M2, C>(
&mut self,
request: Request<S>,
path: PathAndQuery,
codec: C,
) -> Result<Response<M2>, Status>
where
T: GrpcService<BoxBody>,
T::ResponseBody: Body + HttpBody + Send + 'static,
<T::ResponseBody as HttpBody>::Error: Into<crate::Error>,
S: Stream<Item = M1> + Send + Sync + 'static,
C: Codec<Encode = M1, Decode = M2>,
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let (mut parts, body) = self.streaming(request, path, codec).await?.into_parts();
futures_util::pin_mut!(body);
let message = body
.try_next()
.await?
.ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?;
if let Some(trailers) = body.trailers().await? {
parts.merge(trailers);
}
Ok(Response::from_parts(parts, message))
}
pub async fn server_streaming<M1, M2, C>(
&mut self,
request: Request<M1>,
path: PathAndQuery,
codec: C,
) -> Result<Response<Streaming<M2>>, Status>
where
T: GrpcService<BoxBody>,
T::ResponseBody: Body + HttpBody + Send + 'static,
<T::ResponseBody as HttpBody>::Error: Into<crate::Error>,
C: Codec<Encode = M1, Decode = M2>,
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let request = request.map(|m| stream::once(future::ready(m)));
self.streaming(request, path, codec).await
}
pub async fn streaming<S, M1, M2, C>(
&mut self,
request: Request<S>,
path: PathAndQuery,
mut codec: C,
) -> Result<Response<Streaming<M2>>, Status>
where
T: GrpcService<BoxBody>,
T::ResponseBody: Body + HttpBody + Send + 'static,
<T::ResponseBody as HttpBody>::Error: Into<crate::Error>,
S: Stream<Item = M1> + Send + Sync + 'static,
C: Codec<Encode = M1, Decode = M2>,
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let request = if let Some(interceptor) = &self.interceptor {
interceptor.call(request)?
} else {
request
};
let mut parts = Parts::default();
parts.path_and_query = Some(path);
let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
let request = request
.map(|s| encode_client(codec.encoder(), s))
.map(BoxBody::new);
let mut request = request.into_http(uri);
request
.headers_mut()
.insert(TE, HeaderValue::from_static("trailers"));
request
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));
let response = self
.inner
.call(request)
.await
.map_err(|err| Status::from_error(&*(err.into())))?;
let status_code = response.status();
let trailers_only_status = Status::from_header_map(response.headers());
let expect_additional_trailers = if let Some(status) = trailers_only_status {
if status.code() != Code::Ok {
return Err(status);
}
false
} else {
true
};
let response = response.map(|body| {
if expect_additional_trailers {
Streaming::new_response(codec.decoder(), body, status_code)
} else {
Streaming::new_empty(codec.decoder(), body)
}
});
Ok(Response::from_http(response))
}
}
impl<T: Clone> Clone for Grpc<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
interceptor: self.interceptor.clone(),
}
}
}
impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
    /// Formats the client as `Grpc { inner: .. }`.
    ///
    /// Only the inner service is shown; the interceptor is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Grpc");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}