use super::{DecodeBuf, Decoder};
use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
use bytes::{Buf, BufMut, BytesMut};
use futures_core::Stream;
use futures_util::{future, ready};
use http::StatusCode;
use http_body::Body;
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use tracing::{debug, trace};
/// Initial capacity of the decode buffer (8 KiB); also the minimum amount the
/// buffer is grown by when an incoming chunk does not fit.
const BUFFER_SIZE: usize = 8192;
/// A stream of messages of type `T` decoded from an HTTP body.
///
/// Raw body chunks are accumulated in `buf`, split into length-prefixed frames
/// and handed one frame at a time to `decoder`. For responses, the body's
/// trailers are validated once the data is exhausted and cached in `trailers`.
pub struct Streaming<T> {
    /// Turns the payload of one frame into a `T`.
    decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + Sync + 'static>,
    /// The underlying HTTP body the frames are read from.
    body: BoxBody,
    /// Bytes received from `body` that have not been decoded yet.
    buf: BytesMut,
    /// Where the decoder currently is within a frame.
    state: State,
    /// Whether a request or a response is being decoded.
    direction: Direction,
    /// Trailers captured by `poll_next` at end of stream, if any.
    trailers: Option<MetadataMap>,
}

// `message` calls `Pin::new(&mut *self)`, which needs `Streaming<T>: Unpin`.
// Implementing it explicitly makes that hold for every `T`, without relying on
// the auto-trait analysis of the fields.
impl<T> Unpin for Streaming<T> {}
/// Progress of the frame decoder through the buffered bytes.
#[derive(Debug)]
enum State {
    /// Waiting for a complete 5-byte frame prefix (1-byte compression flag
    /// followed by a 4-byte length).
    ReadHeader,
    /// The prefix has been consumed; waiting for `len` payload bytes.
    ReadBody {
        /// The compression flag taken from the prefix.
        compression: bool,
        /// The payload length taken from the prefix.
        len: usize,
    },
}
/// Which kind of body is being decoded. This selects the wording of protocol
/// error messages and whether trailers are checked at end of stream.
#[derive(Debug)]
enum Direction {
    /// Decoding a request body.
    Request,
    /// Decoding a response received with this HTTP status; trailers are polled
    /// and checked against it once the body's data ends.
    Response(StatusCode),
    /// Decoding a response without a tracked status; `poll_next` does not poll
    /// trailers for this variant.
    EmptyResponse,
}
impl<T> Streaming<T> {
pub(crate) fn new_response<B, D>(decoder: D, body: B, status_code: StatusCode) -> Self
where
B: Body + Send + Sync + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
{
Self::new(decoder, body, Direction::Response(status_code))
}
pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
where
B: Body + Send + Sync + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
{
Self::new(decoder, body, Direction::EmptyResponse)
}
#[doc(hidden)]
pub fn new_request<B, D>(decoder: D, body: B) -> Self
where
B: Body + Send + Sync + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
{
Self::new(decoder, body, Direction::Request)
}
fn new<B, D>(decoder: D, body: B, direction: Direction) -> Self
where
B: Body + Send + Sync + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + Sync + 'static,
{
Self {
decoder: Box::new(decoder),
body: BoxBody::map_from(body),
state: State::ReadHeader,
direction,
buf: BytesMut::with_capacity(BUFFER_SIZE),
trailers: None,
}
}
}
impl<T> Streaming<T> {
    /// Fetches the next message from this stream.
    ///
    /// Returns `Ok(None)` once the stream has ended.
    ///
    /// # Errors
    ///
    /// Returns the `Status` produced by `poll_next` when decoding, reading the
    /// body, or (for responses) validating the trailers fails.
    pub async fn message(&mut self) -> Result<Option<T>, Status> {
        future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx))
            .await
            .transpose()
    }

    /// Returns the trailers of the underlying body, if it has any.
    ///
    /// Any messages not yet read are drained and discarded first, so the body
    /// reaches its end before the trailers are looked at.
    ///
    /// # Errors
    ///
    /// Returns a `Status` if draining the remaining messages fails or if
    /// polling the body's trailers fails.
    pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
        // `poll_next` may already have captured them (response direction).
        if let Some(trailers) = self.trailers.take() {
            return Ok(Some(trailers));
        }

        // Consume (and drop) the rest of the messages.
        while self.message().await?.is_some() {}

        if let Some(trailers) = self.trailers.take() {
            return Ok(Some(trailers));
        }

        // No cached trailers (e.g. request direction): ask the body directly.
        let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx))
            .await
            .map_err(|e| Status::from_error(&e))?;
        Ok(map.map(MetadataMap::from_headers))
    }

    /// Tries to decode one message from the bytes already buffered.
    ///
    /// Each frame is a 5-byte prefix (a 1-byte compression flag and a 4-byte
    /// length read with `get_u32`) followed by `length` payload bytes.
    /// Returns `Ok(None)` when more data is needed.
    ///
    /// # Errors
    ///
    /// * `Code::Unimplemented` if the frame is flagged as compressed.
    /// * `Code::Internal` if the compression flag is neither 0 nor 1.
    /// * Any error returned by the decoder.
    fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
        if let State::ReadHeader = self.state {
            if self.buf.remaining() < 5 {
                return Ok(None);
            }

            let is_compressed = match self.buf.get_u8() {
                0 => false,
                1 => {
                    trace!("message compressed, compression not supported yet");
                    return Err(Status::new(
                        Code::Unimplemented,
                        "Message compressed, compression not supported yet.".to_string(),
                    ));
                }
                f => {
                    trace!("unexpected compression flag");
                    let message = if let Direction::Response(status) = self.direction {
                        format!(
                            "Unexpected compression flag: {}, while receiving response with status: {}",
                            f, status
                        )
                    } else {
                        format!("Unexpected compression flag: {}, while sending request", f)
                    };
                    return Err(Status::new(Code::Internal, message));
                }
            };

            // NOTE(review): the announced length is not bounded here, so a peer
            // can make the buffer grow up to `u32::MAX` bytes.
            let len = self.buf.get_u32() as usize;
            self.state = State::ReadBody {
                compression: is_compressed,
                len,
            };
        }

        if let State::ReadBody { len, .. } = self.state {
            // For `BytesMut`, `remaining()` is the number of buffered bytes.
            if self.buf.remaining() < len {
                return Ok(None);
            }

            let decoded = self
                .decoder
                .decode(&mut DecodeBuf::new(&mut self.buf, len))?;
            // Only move on to the next frame once the decoder produced a message.
            if decoded.is_some() {
                self.state = State::ReadHeader;
            }
            return Ok(decoded);
        }

        Ok(None)
    }
}
impl<T> Stream for Streaming<T> {
type Item = Result<T, Status>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match self.decode_chunk()? {
Some(item) => return Poll::Ready(Some(Ok(item))),
None => (),
}
let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
Some(Ok(d)) => Some(d),
Some(Err(e)) => {
let err: crate::Error = e.into();
debug!("decoder inner stream error: {:?}", err);
let status = Status::from_error(&*err);
Err(status)?;
break;
}
None => None,
};
if let Some(data) = chunk {
if data.remaining() > self.buf.remaining_mut() {
let amt = if data.remaining() > BUFFER_SIZE {
data.remaining()
} else {
BUFFER_SIZE
};
self.buf.reserve(amt);
}
self.buf.put(data);
} else {
if self.buf.has_remaining() {
trace!("unexpected EOF decoding stream");
Err(Status::new(
Code::Internal,
"Unexpected EOF decoding stream.".to_string(),
))?;
} else {
break;
}
}
}
if let Direction::Response(status) = self.direction {
match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
Ok(trailer) => {
if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) {
return Some(Err(e)).into();
} else {
self.trailers = trailer.map(MetadataMap::from_headers);
}
}
Err(e) => {
let err: crate::Error = e.into();
debug!("decoder inner trailers error: {:?}", err);
let status = Status::from_error(&*err);
return Some(Err(status)).into();
}
}
}
Poll::Ready(None)
}
}
impl<T> fmt::Debug for Streaming<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the type name is printed; no fields are shown (the decoder
        // is a trait object).
        let mut out = f.debug_struct("Streaming");
        out.finish()
    }
}
// Compile-time check (test builds only) that `Streaming`, instantiated with
// `()` as the message type, is both `Sync` and `Send`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Sync, Send);