use super::{Instrument, InstrumentFuture, NoInstrument};
use crate::Load;
use futures_core::ready;
use log::trace;
use pin_project::pin_project;
use std::{
pin::Pin,
task::{Context, Poll},
};
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use tokio::time::Instant;
use tower_discover::{Change, Discover};
use tower_service::Service;
/// Wraps an inner service and tracks a Peak-EWMA estimate of its round-trip
/// time, which it exposes through `Load` as a `Cost`.
#[derive(Debug)]
pub struct PeakEwma<S, I = NoInstrument> {
    /// The wrapped service that requests are forwarded to.
    service: S,
    /// Decay time constant in nanoseconds: the divisor in `exp(-elapsed / decay_ns)`
    /// used when relaxing the estimate toward lower observed RTTs.
    decay_ns: f64,
    /// Shared RTT estimate. Every outstanding `Handle` holds a clone of this `Arc`,
    /// so its strong count also yields the number of pending requests.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    /// Instrumentation applied to responses; cloned for every call.
    instrument: I,
}
/// Wraps a `Discover` stream so that every inserted service is wrapped in a
/// `PeakEwma` load tracker.
#[pin_project]
#[derive(Debug)]
pub struct PeakEwmaDiscover<D, I = NoInstrument> {
    /// The underlying discovery stream (structurally pinned).
    #[pin]
    discover: D,
    /// Decay time constant, in nanoseconds, handed to every `PeakEwma`.
    decay_ns: f64,
    /// Initial RTT estimate for newly discovered services.
    default_rtt: Duration,
    /// Instrumentation cloned into every `PeakEwma`.
    instrument: I,
}
/// Load metric reported by `PeakEwma`: the current RTT estimate (nanoseconds)
/// multiplied by the number of pending requests plus one.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
/// Tracks one in-flight request. When dropped, it records the time elapsed since
/// `sent_at` into the shared RTT estimate.
#[derive(Debug)]
pub struct Handle {
    /// When the request was issued.
    sent_at: Instant,
    /// Decay time constant, in nanoseconds, copied from the owning `PeakEwma`.
    decay_ns: f64,
    /// Shared estimate. Holding this clone also counts the request as pending.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
/// State guarded by the shared mutex: the current RTT estimate and when it was
/// last updated.
#[derive(Debug)]
struct RttEstimate {
    /// Time of the most recent update; the decay is proportional to time since then.
    update_at: Instant,
    /// Current RTT estimate in nanoseconds.
    rtt_ns: f64,
}
/// Nanoseconds in one millisecond; converts nanosecond estimates to ms for logs.
const NANOS_PER_MILLI: f64 = 1e6;
impl<D, I> PeakEwmaDiscover<D, I> {
    /// Wraps `discover` so that every service it yields is tracked with Peak-EWMA.
    ///
    /// `default_rtt` seeds the RTT estimate of each newly discovered service.
    /// `decay` is the time constant for relaxing an estimate toward lower observed
    /// RTTs. `instrument` is cloned into each wrapped service.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, instrument: I) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        I: Instrument<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);
        Self {
            discover,
            decay_ns,
            default_rtt,
            instrument,
        }
    }
}
impl<D, I> Discover for PeakEwmaDiscover<D, I>
where
    D: Discover,
    I: Clone,
{
    type Key = D::Key;
    type Service = PeakEwma<D::Service, I>;
    type Error = D::Error;

    /// Polls the inner discovery stream.
    ///
    /// Inserted services are wrapped in a fresh `PeakEwma` and removals are
    /// forwarded unchanged. Pending states and errors from the inner stream
    /// propagate as-is.
    fn poll_discover(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Change<D::Key, Self::Service>, D::Error>> {
        let this = self.project();
        let inner = ready!(this.discover.poll_discover(cx))?;
        let wrapped = match inner {
            Change::Insert(key, service) => {
                let tracked = PeakEwma::new(
                    service,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.instrument.clone(),
                );
                Change::Insert(key, tracked)
            }
            Change::Remove(key) => Change::Remove(key),
        };
        Poll::Ready(Ok(wrapped))
    }
}
impl<S, I> PeakEwma<S, I> {
    /// Wraps `service`, seeding its shared RTT estimate with `default_rtt`.
    fn new(service: S, default_rtt: Duration, decay_ns: f64, instrument: I) -> Self {
        let seeded = RttEstimate::new(nanos(default_rtt));
        let rtt_estimate = Arc::new(Mutex::new(seeded));
        Self {
            service,
            decay_ns,
            rtt_estimate,
            instrument,
        }
    }

    /// Creates a `Handle` stamped with the current instant.
    ///
    /// The handle shares this service's estimate, which bumps the `Arc` strong
    /// count used as the pending-request count.
    fn handle(&self) -> Handle {
        Handle {
            sent_at: Instant::now(),
            decay_ns: self.decay_ns,
            rtt_estimate: Arc::clone(&self.rtt_estimate),
        }
    }
}
impl<S, I, Request> Service<Request> for PeakEwma<S, I>
where
    S: Service<Request>,
    I: Instrument<Handle, S::Response>,
{
    type Response = I::Output;
    type Error = S::Error;
    type Future = InstrumentFuture<S::Future, I, Handle>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Forwards `req` to the inner service.
    ///
    /// The response future is wrapped together with a fresh `Handle` and a clone
    /// of the instrument. The request's RTT is recorded whenever that handle is
    /// eventually dropped. When that happens depends on the `Instrument` impl,
    /// which is not visible here.
    fn call(&mut self, req: Request) -> Self::Future {
        // Order matters only for timing: the handle is stamped before the
        // inner call starts.
        let instrument = self.instrument.clone();
        let handle = self.handle();
        let response = self.service.call(req);
        InstrumentFuture::new(instrument, handle, response)
    }
}
impl<S, I> Load for PeakEwma<S, I> {
    type Metric = Cost;

    /// Returns the decayed RTT estimate scaled by `pending + 1`.
    ///
    /// `pending` is the number of outstanding `Handle`s: the `Arc` strong count
    /// minus this service's own reference.
    fn load(&self) -> Self::Metric {
        let outstanding = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
        let rtt_ns = self.update_estimate();
        let weight = f64::from(outstanding + 1);
        let cost = Cost(rtt_ns * weight);
        trace!(
            "load estimate={:.0}ms pending={} cost={:?}",
            rtt_ns / NANOS_PER_MILLI,
            outstanding,
            cost,
        );
        cost
    }
}
impl<S, I> PeakEwma<S, I> {
    /// Decays the shared estimate up to the current instant and returns it, in ns.
    ///
    /// # Panics
    ///
    /// Panics if the estimate's mutex has been poisoned.
    fn update_estimate(&self) -> f64 {
        self.rtt_estimate
            .lock()
            .expect("peak ewma prior_estimate")
            .decay(self.decay_ns)
    }
}
impl RttEstimate {
    /// Creates an estimate of `rtt_ns` nanoseconds, timestamped now.
    fn new(rtt_ns: f64) -> Self {
        debug_assert!(0.0 < rtt_ns, "rtt must be positive");
        Self {
            rtt_ns,
            update_at: Instant::now(),
        }
    }

    /// Decays the estimate to the current instant without a new observation.
    ///
    /// This is fed to `update` as a zero-length RTT sample. For a non-negative
    /// estimate that always takes the decay branch, pulling the estimate toward
    /// zero in proportion to the time elapsed since the last update.
    fn decay(&mut self, decay_ns: f64) -> f64 {
        let now = Instant::now();
        self.update(now, now, decay_ns)
    }

    /// Records an observed RTT (`recv_at - sent_at`) and returns the new estimate in ns.
    ///
    /// - An observation above the current estimate replaces it outright (the "peak").
    /// - Otherwise the prior estimate is blended toward the observation. The
    ///   prior keeps weight `exp(-elapsed / decay_ns)`, where `elapsed` is the
    ///   time since the previous update.
    ///
    /// Both durations use `saturating_duration_since`. If the clock appears to
    /// go backwards, the result is a zero duration instead of a panic in release
    /// builds. Debug builds still flag it through the `debug_assert!`s.
    fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
        debug_assert!(
            sent_at <= recv_at,
            "recv_at={:?} before sent_at={:?}",
            recv_at,
            sent_at
        );
        let rtt = nanos(recv_at.saturating_duration_since(sent_at));

        let now = Instant::now();
        debug_assert!(
            self.update_at <= now,
            "update_at={:?} in the future",
            self.update_at
        );

        self.rtt_ns = if self.rtt_ns < rtt {
            // Peak: a slower observation immediately becomes the estimate.
            trace!(
                "update peak rtt={}ms prior={}ms",
                rtt / NANOS_PER_MILLI,
                self.rtt_ns / NANOS_PER_MILLI,
            );
            rtt
        } else {
            // Relax the prior estimate toward the (not larger) observation. The
            // longer since the last update, the more weight the observation gets.
            let elapsed = nanos(now.saturating_duration_since(self.update_at));
            let decay = (-elapsed / decay_ns).exp();
            let recency = 1.0 - decay;
            let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
            trace!(
                "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
                rtt / NANOS_PER_MILLI,
                self.rtt_ns - next_estimate,
                next_estimate / NANOS_PER_MILLI,
            );
            next_estimate
        };
        self.update_at = now;
        self.rtt_ns
    }
}
impl Drop for Handle {
    /// Records this request's round-trip time into the shared estimate.
    ///
    /// If the mutex is poisoned, the sample is skipped. Dropping a `Handle`
    /// therefore never panics because of poisoning.
    fn drop(&mut self) {
        let recv_at = Instant::now();
        let locked = self.rtt_estimate.lock();
        if let Ok(mut estimate) = locked {
            estimate.update(self.sent_at, recv_at, self.decay_ns);
        }
    }
}
/// Converts `d` to a (possibly rounded) number of nanoseconds as `f64`.
///
/// The whole-seconds part is scaled with saturating arithmetic. Durations too
/// large to express in `u64` nanoseconds therefore clamp that part to
/// `u64::MAX` before the sub-second nanoseconds are added.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    let fractional_ns = f64::from(d.subsec_nanos());
    whole_secs_ns + fractional_ns
}
// These tests run on tokio's paused clock (`time::pause`). `Instant::now()`
// only moves when `time::advance` is called, so every elapsed interval is exact.
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    // Trivial service whose response future resolves immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    // With no traffic, each `load()` decays the estimate by exp(-elapsed / decay).
    // With a 1s decay, that is exp(-0.1) ≈ 0.905 per 100ms step.
    #[tokio::test]
    async fn default_decay() {
        time::pause();
        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        // No time has elapsed yet, so the load is exactly the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        // One 100ms step: ~9.05ms.
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        // Two steps: ~8.19ms.
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    // Exercises pending-request weighting, peak updates, and decay together.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();
        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // t=100ms: one request pending, so the (~18.1ms) estimate is doubled.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // t=200ms: two requests pending, so the (~16.4ms) estimate is tripled.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // t=300ms: rsp0 resolves with a 200ms RTT, which exceeds the estimate and
        // becomes the new peak. rsp1 is still pending, so the cost is 200ms * 2.
        // NOTE(review): this relies on the handle being released when the response
        // resolves under `NoInstrument`. That impl is not visible in this file.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // t=400ms: rsp1 resolves with a 200ms RTT equal to the estimate. Blending
        // 200ms with 200ms leaves it at 200ms, and nothing is pending now.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // 1s later: 200ms * exp(-1) ≈ 73.6ms.
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        // 10s later: ≈ 73.6ms * exp(-10), i.e. a few microseconds.
        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    // Pins `nanos` conversion, including the saturating seconds term. For u64::MAX
    // seconds, the seconds part clamps at u64::MAX ns before the sub-second
    // nanoseconds are added.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}