use crate::error;
use futures_core::ready;
use futures_util::future::{self, TryFutureExt};
use pin_project::pin_project;
use rand::{rngs::SmallRng, SeedableRng};
use std::marker::PhantomData;
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use tower_discover::{Change, Discover};
use tower_load::Load;
use tower_ready_cache::{error::Failed, ReadyCache};
use tower_service::Service;
use tracing::{debug, trace};
/// Load balancer that dispatches each request to one service taken from a
/// discovery stream.
///
/// Services reported by `discover` are stored in `services`; when more than
/// one of them is ready, two are sampled at random and the one reporting the
/// lower load is selected (see `p2c_ready_index`).
pub struct Balance<D: Discover, Req> {
// Source of `Change::Insert` / `Change::Remove` events for the endpoint set.
discover: D,
// Cache holding the discovered services, split into ready and pending sets.
services: ReadyCache<D::Key, D::Service, Req>,
// Index (into the cache's ready set) chosen by `poll_ready` and consumed by
// `call`; `None` when no service is currently selected.
ready_index: Option<usize>,
// Random number generator used to sample the two candidate indices.
rng: SmallRng,
// Ties the balancer to its request type without storing a request.
_req: PhantomData<Req>,
}
impl<D: Discover, Req> fmt::Debug for Balance<D, Req>
where
    D: fmt::Debug,
    D::Key: fmt::Debug,
    D::Service: fmt::Debug,
    Req: fmt::Debug,
{
    /// Renders the balancer as a `Balance` struct containing only its
    /// `discover` and `services` fields; the remaining fields are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Balance");
        builder.field("discover", &self.discover);
        builder.field("services", &self.services);
        builder.finish()
    }
}
/// Future that waits for a keyed service to become ready, unless a
/// cancellation signal arrives first.
///
/// Resolves to `Ok((key, service))` when the service reports readiness, or to
/// `Err((key, Error))` when the service fails or the future is canceled.
// NOTE(review): nothing in the visible part of this file constructs this
// type — confirm whether it is still used elsewhere in the crate.
#[pin_project]
#[derive(Debug)]
struct UnreadyService<K, S, Req> {
// Key identifying the service; taken out when the future completes.
key: Option<K>,
// Cancellation signal; a successfully received `()` resolves the future with
// `Error::Canceled`.
#[pin]
cancel: oneshot::Receiver<()>,
// The service being polled for readiness; taken out when that poll completes.
service: Option<S>,
// Ties the future to its request type without storing a request.
_req: PhantomData<Req>,
}
/// Reason an `UnreadyService` future completed without producing a ready
/// service.
enum Error<E> {
// The wrapped service's `poll_ready` returned this error.
Inner(E),
// The cancellation receiver yielded a value before the service became ready.
Canceled,
}
impl<D, Req> Balance<D, Req>
where
    D: Discover,
    D::Service: Service<Req>,
    <D::Service as Service<Req>>::Error: Into<error::Error>,
{
    /// Creates a balancer that draws its endpoints from `discover` and uses
    /// `rng` as the source of randomness when sampling candidate services.
    ///
    /// The balancer starts with an empty service cache and no selected
    /// service.
    pub fn new(discover: D, rng: SmallRng) -> Self {
        Self {
            rng,
            discover,
            services: ReadyCache::default(),
            ready_index: None,
            _req: PhantomData,
        }
    }

    /// Creates a balancer over `discover` whose random number generator is
    /// seeded via `SmallRng::from_entropy`.
    pub fn from_entropy(discover: D) -> Self {
        Self::new(discover, SmallRng::from_entropy())
    }

    /// Returns the number of services reported by the underlying service
    /// cache.
    pub fn len(&self) -> usize {
        self.services.len()
    }

    /// Returns `true` when the underlying service cache reports no services.
    ///
    /// Provided alongside [`len`](Self::len) so callers do not have to write
    /// `balance.len() == 0` themselves.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<D, Req> Balance<D, Req>
where
D: Discover + Unpin,
D::Key: Clone,
D::Error: Into<error::Error>,
D::Service: Service<Req> + Load,
<D::Service as Load>::Metric: std::fmt::Debug,
<D::Service as Service<Req>>::Error: Into<error::Error>,
{
fn update_pending_from_discover(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<(), error::Discover>> {
debug!("updating from discover");
loop {
match ready!(Pin::new(&mut self.discover).poll_discover(cx))
.map_err(|e| error::Discover(e.into()))?
{
Change::Remove(key) => {
trace!("remove");
self.services.evict(&key);
}
Change::Insert(key, svc) => {
trace!("insert");
self.services.push(key, svc);
}
}
}
}
fn promote_pending_to_ready(&mut self, cx: &mut Context<'_>) {
loop {
match self.services.poll_pending(cx) {
Poll::Ready(Ok(())) => {
debug_assert_eq!(self.services.pending_len(), 0);
break;
}
Poll::Pending => {
debug_assert!(self.services.pending_len() > 0);
break;
}
Poll::Ready(Err(error)) => {
debug!(%error, "dropping failed endpoint");
}
}
}
trace!(
ready = %self.services.ready_len(),
pending = %self.services.pending_len(),
"poll_unready"
);
}
fn p2c_ready_index(&mut self) -> Option<usize> {
match self.services.ready_len() {
0 => None,
1 => Some(0),
len => {
let idxs = rand::seq::index::sample(&mut self.rng, len, 2);
let aidx = idxs.index(0);
let bidx = idxs.index(1);
debug_assert_ne!(aidx, bidx, "random indices must be distinct");
let aload = self.ready_index_load(aidx);
let bload = self.ready_index_load(bidx);
let chosen = if aload <= bload { aidx } else { bidx };
trace!(
a.index = aidx,
a.load = ?aload,
b.index = bidx,
b.load = ?bload,
chosen = if chosen == aidx { "a" } else { "b" },
"p2c",
);
Some(chosen)
}
}
}
fn ready_index_load(&self, index: usize) -> <D::Service as Load>::Metric {
let (_, svc) = self.services.get_ready_index(index).expect("invalid index");
svc.load()
}
pub(crate) fn discover_mut(&mut self) -> &mut D {
&mut self.discover
}
}
impl<D, Req> Service<Req> for Balance<D, Req>
where
    D: Discover + Unpin,
    D::Key: Clone,
    D::Error: Into<error::Error>,
    D::Service: Service<Req> + Load,
    <D::Service as Load>::Metric: std::fmt::Debug,
    <D::Service as Service<Req>>::Error: Into<error::Error>,
{
    type Response = <D::Service as Service<Req>>::Response;
    type Error = error::Error;
    type Future = future::MapErr<
        <D::Service as Service<Req>>::Future,
        fn(<D::Service as Service<Req>>::Error) -> error::Error,
    >;

    /// Refreshes the endpoint set and makes sure a ready service is selected.
    ///
    /// Returns `Poll::Ready(Ok(()))` once `ready_index` names a service that
    /// was confirmed ready during this call, `Poll::Pending` when no ready
    /// service exists, and an error only when discovery fails.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // The returned `Poll` is deliberately ignored: only a discovery error
        // (propagated by `?`) stops readiness polling.
        let _ = self.update_pending_from_discover(cx)?;
        self.promote_pending_to_ready(cx);

        loop {
            // Re-validate a previously selected service before trusting it.
            if let Some(candidate) = self.ready_index.take() {
                match self.services.check_ready_index(cx, candidate) {
                    Ok(true) => {
                        self.ready_index = Some(candidate);
                        return Poll::Ready(Ok(()));
                    }
                    Ok(false) => {
                        trace!("ready service became unavailable");
                    }
                    Err(Failed(_, error)) => {
                        debug!(%error, "endpoint failed");
                    }
                }
            }

            // `ready_index` is `None` at this point; pick a fresh candidate
            // and loop around to verify it.
            match self.p2c_ready_index() {
                Some(next) => {
                    self.ready_index = Some(next);
                }
                None => {
                    debug_assert_eq!(self.services.ready_len(), 0);
                    return Poll::Pending;
                }
            }
        }
    }

    /// Dispatches `request` to the service selected by the preceding
    /// `poll_ready`, converting the inner error into `error::Error`.
    ///
    /// # Panics
    ///
    /// Panics with "called before ready" when no service is selected.
    fn call(&mut self, request: Req) -> Self::Future {
        let selected = self.ready_index.take().expect("called before ready");
        let response = self.services.call_ready_index(selected, request);
        response.map_err(Into::into)
    }
}
impl<K, S: Service<Req>, Req> Future for UnreadyService<K, S, Req> {
type Output = Result<(K, S), (K, Error<S::Error>)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(Ok(())) = this.cancel.poll(cx) {
let key = this.key.take().expect("polled after ready");
return Poll::Ready(Err((key, Error::Canceled)));
}
let res = ready!(this
.service
.as_mut()
.expect("poll after ready")
.poll_ready(cx));
let key = this.key.take().expect("polled after ready");
let svc = this.service.take().expect("polled after ready");
match res {
Ok(()) => Poll::Ready(Ok((key, svc))),
Err(e) => Poll::Ready(Err((key, Error::Inner(e)))),
}
}
}