use bincode::{serialize, Error};
use lru::LruCache;
use rand::{AsByteSliceMut, CryptoRng, Rng};
use serde::Serialize;
use solana_sdk::hash::{self, Hash};
use solana_sdk::pubkey::Pubkey;
use solana_sdk::sanitize::{Sanitize, SanitizeError};
use solana_sdk::signature::{Keypair, Signable, Signature, Signer};
use std::borrow::Cow;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// A signed ping message carrying an arbitrary token of type `T`.
///
/// NOTE(review): field order determines the serialized layout produced by the
/// derived `Serialize`/`Deserialize` impls; do not reorder fields.
#[derive(AbiExample, Debug, Deserialize, Serialize)]
pub struct Ping<T> {
// Public key of the keypair that signed this ping.
from: Pubkey,
// Payload whose serialized bytes are what gets signed.
token: T,
// Signature over the serialized `token`.
signature: Signature,
}
/// A signed response to a `Ping`, identifying the ping by the hash of its
/// serialized token.
///
/// NOTE(review): field order determines the serialized layout produced by the
/// derived `Serialize`/`Deserialize` impls; do not reorder fields.
#[derive(AbiExample, Debug, Deserialize, Serialize)]
pub struct Pong {
// Public key of the keypair that signed this pong.
from: Pubkey,
// Hash of the serialized token of the ping being answered.
hash: Hash,
// Signature over the bytes of `hash`.
signature: Signature,
}
/// Tracks which remote nodes (identified by public key and socket address)
/// have recently answered a ping with a matching pong.
pub struct PingCache {
// Lifetime of a received pong. Fractions of it are also used as thresholds:
// `ttl / 8` before a fresh ping is requested and `ttl / 64` as the minimum
// delay between consecutive pings to the same node.
ttl: Duration,
// Time at which the most recent ping was generated for each node.
pings: LruCache<(Pubkey, SocketAddr), Instant>,
// Time at which the most recent matching pong was recorded for each node.
pongs: LruCache<(Pubkey, SocketAddr), Instant>,
// Hash of each outstanding ping token, mapped to the node it was generated for.
pending_cache: LruCache<Hash, (Pubkey, SocketAddr)>,
}
impl<T: Serialize> Ping<T> {
    /// Builds a ping carrying `token`, signed by `keypair`.
    ///
    /// The signature covers the serialized bytes of `token`.
    ///
    /// # Errors
    ///
    /// Returns the serialization error if `token` cannot be serialized.
    pub fn new(token: T, keypair: &Keypair) -> Result<Self, Error> {
        let token_bytes = serialize(&token)?;
        Ok(Self {
            from: keypair.pubkey(),
            signature: keypair.sign_message(&token_bytes),
            token,
        })
    }
}
impl<T> Ping<T>
where
T: Serialize + AsByteSliceMut + Default,
{
pub fn new_rand<R>(rng: &mut R, keypair: &Keypair) -> Result<Self, Error>
where
R: Rng + CryptoRng,
{
let mut token = T::default();
rng.fill(&mut token);
Ping::new(token, keypair)
}
}
impl<T> Sanitize for Ping<T> {
    /// Sanitizes the sender public key and then the signature, returning the
    /// first error encountered. The token itself is not inspected.
    fn sanitize(&self) -> Result<(), SanitizeError> {
        self.from
            .sanitize()
            .and_then(|()| self.signature.sanitize())
    }
}
impl<T: Serialize> Signable for Ping<T> {
    /// Public key of the node that created the ping.
    fn pubkey(&self) -> Pubkey {
        self.from
    }

    /// Serialized token bytes; this is the data covered by the signature.
    ///
    /// # Panics
    ///
    /// Panics if the token cannot be serialized.
    fn signable_data(&self) -> Cow<[u8]> {
        let token_bytes = serialize(&self.token).unwrap();
        Cow::Owned(token_bytes)
    }

    /// Signature currently stored in the ping.
    fn get_signature(&self) -> Signature {
        self.signature
    }

    /// Replaces the stored signature.
    fn set_signature(&mut self, signature: Signature) {
        self.signature = signature;
    }
}
impl Pong {
    /// Builds the pong answering `ping`, signed by `keypair`.
    ///
    /// The pong carries the hash of the ping's serialized token, and the
    /// signature covers the bytes of that hash.
    ///
    /// # Errors
    ///
    /// Returns the serialization error if the ping token cannot be serialized.
    pub fn new<T: Serialize>(ping: &Ping<T>, keypair: &Keypair) -> Result<Self, Error> {
        let token_bytes = serialize(&ping.token)?;
        let digest = hash::hash(&token_bytes);
        let signature = keypair.sign_message(digest.as_ref());
        Ok(Self {
            from: keypair.pubkey(),
            hash: digest,
            signature,
        })
    }
}
impl Sanitize for Pong {
    /// Sanitizes the sender public key, the hash and the signature, in that
    /// order, returning the first error encountered.
    fn sanitize(&self) -> Result<(), SanitizeError> {
        self.from
            .sanitize()
            .and_then(|()| self.hash.sanitize())
            .and_then(|()| self.signature.sanitize())
    }
}
impl Signable for Pong {
    /// Public key of the node that created the pong.
    fn pubkey(&self) -> Pubkey {
        self.from
    }

    /// Owned copy of the hash bytes; this is the data covered by the signature.
    fn signable_data(&self) -> Cow<[u8]> {
        let hash_bytes: &[u8] = self.hash.as_ref();
        Cow::Owned(hash_bytes.to_vec())
    }

    /// Signature currently stored in the pong.
    fn get_signature(&self) -> Signature {
        self.signature
    }

    /// Replaces the stored signature.
    fn set_signature(&mut self, signature: Signature) {
        self.signature = signature;
    }
}
impl PingCache {
    /// Creates an empty cache.
    ///
    /// `ttl` is the lifetime of a received pong; `cap` is the capacity given
    /// to each of the three internal LRU caches.
    pub fn new(ttl: Duration, cap: usize) -> Self {
        Self {
            ttl,
            pings: LruCache::new(cap),
            pongs: LruCache::new(cap),
            pending_cache: LruCache::new(cap),
        }
    }

    /// Records a pong received from `socket` at time `now`.
    ///
    /// Returns `true` only if the pong's hash matches an outstanding ping that
    /// was generated for exactly this `(pubkey, socket)` pair. In that case
    /// the node's ping timestamp and the pending entry are removed and the
    /// pong timestamp is stored. Otherwise the cache is left untouched and
    /// `false` is returned.
    pub fn add(&mut self, pong: &Pong, socket: SocketAddr, now: Instant) -> bool {
        let node = (pong.pubkey(), socket);
        match self.pending_cache.peek(&pong.hash) {
            Some(value) if *value == node => {
                self.pings.pop(&node);
                self.pongs.put(node, now);
                self.pending_cache.pop(&pong.hash);
                true
            }
            _ => false,
        }
    }

    /// Generates a ping for `node` unless one was generated less than
    /// `ttl / 64` ago.
    ///
    /// `pingf` is only invoked when the rate limit allows a new ping. If it
    /// returns `None`, or the token fails to serialize, nothing is recorded
    /// and `None` is returned. On success the ping time and the hash of the
    /// serialized token are stored so that a later pong can be matched.
    fn maybe_ping<T, F>(
        &mut self,
        now: Instant,
        node: (Pubkey, SocketAddr),
        mut pingf: F,
    ) -> Option<Ping<T>>
    where
        T: Serialize,
        F: FnMut() -> Option<Ping<T>>,
    {
        // Rate limit consecutive pings generated for the same node.
        let delay = self.ttl / 64;
        match self.pings.peek(&node) {
            Some(t) if now.saturating_duration_since(*t) < delay => None,
            _ => {
                let ping = pingf()?;
                let hash = hash::hash(&serialize(&ping.token).ok()?);
                self.pings.put(node, now);
                self.pending_cache.put(hash, node);
                Some(ping)
            }
        }
    }

    /// Checks whether `node` has a recorded pong and decides whether a new
    /// ping should be generated for it.
    ///
    /// Returns `(check, ping)`:
    /// * `check` is `true` if a pong entry exists for `node` when this call is
    ///   made. An entry older than `ttl` still yields `true` for this call but
    ///   is removed, so subsequent calls return `false` until a new pong is
    ///   added.
    /// * `ping` is a freshly generated ping when there is no pong entry or the
    ///   entry is older than `ttl / 8`, subject to the rate limit applied by
    ///   `maybe_ping`; otherwise `None`.
    pub fn check<T, F>(
        &mut self,
        now: Instant,
        node: (Pubkey, SocketAddr),
        pingf: F,
    ) -> (bool, Option<Ping<T>>)
    where
        T: Serialize,
        F: FnMut() -> Option<Ping<T>>,
    {
        let (check, should_ping) = match self.pongs.get(&node) {
            None => (false, true),
            Some(t) => {
                let age = now.saturating_duration_since(*t);
                // Drop the pong once it has outlived the ttl.
                if age > self.ttl {
                    self.pongs.pop(&node);
                }
                // Request a new ping once the pong is no longer recent, so the
                // node can be re-verified before the entry expires.
                (true, age > self.ttl / 8)
            }
        };
        let ping = if should_ping {
            self.maybe_ping(now, node, pingf)
        } else {
            None
        };
        (check, ping)
    }

    /// Returns a copy of this cache with the same ttl, capacities and entries.
    ///
    /// Each cache is copied by iterating the source in reverse and re-inserting
    /// every entry.
    /// NOTE(review): presumably `iter()` yields most-recently-used entries
    /// first, so the reversed insertion keeps the recency order; confirm
    /// against the `lru` crate documentation.
    pub(crate) fn mock_clone(&self) -> Self {
        let mut clone = Self {
            ttl: self.ttl,
            pings: LruCache::new(self.pings.cap()),
            pongs: LruCache::new(self.pongs.cap()),
            pending_cache: LruCache::new(self.pending_cache.cap()),
        };
        // Copy ping timestamps from `self.pings` (not `self.pongs`) so the
        // clone keeps the same rate-limit state as the source.
        for (k, v) in self.pings.iter().rev() {
            clone.pings.put(*k, *v);
        }
        for (k, v) in self.pongs.iter().rev() {
            clone.pongs.put(*k, *v);
        }
        for (k, v) in self.pending_cache.iter().rev() {
            clone.pending_cache.put(*k, *v);
        }
        clone
    }
}
// Unit tests for ping/pong signing and for the PingCache state machine.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
use std::iter::repeat_with;
use std::net::{Ipv4Addr, SocketAddrV4};
// Token type used for all pings in these tests.
type Token = [u8; 32];
// A random ping and the pong built from it must both verify and sanitize,
// and the pong must carry the hash of the ping token.
#[test]
fn test_ping_pong() {
let mut rng = rand::thread_rng();
let keypair = Keypair::new();
let ping = Ping::<Token>::new_rand(&mut rng, &keypair).unwrap();
assert!(ping.verify());
assert!(ping.sanitize().is_ok());
let pong = Pong::new(&ping, &keypair).unwrap();
assert!(pong.verify());
assert!(pong.sanitize().is_ok());
assert_eq!(hash::hash(&ping.token), pong.hash);
}
// Walks the cache through its timing thresholds (ttl / 64, ttl / 8 and ttl)
// using explicitly advanced `now` values instead of sleeping.
#[test]
fn test_ping_cache() {
let now = Instant::now();
let mut rng = rand::thread_rng();
let ttl = Duration::from_millis(256);
let mut cache = PingCache::new(ttl, 1000);
let this_node = Keypair::new();
let keypairs: Vec<_> = repeat_with(Keypair::new).take(8).collect();
let sockets: Vec<_> = repeat_with(|| {
SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen()),
rng.gen(),
))
})
.take(8)
.collect();
// 128 (keypair, socket) pairs drawn at random from 8 keypairs and 8 sockets,
// so the same node can appear more than once.
let remote_nodes: Vec<(&Keypair, SocketAddr)> = repeat_with(|| {
let keypair = &keypairs[rng.gen_range(0, keypairs.len())];
let socket = sockets[rng.gen_range(0, sockets.len())];
(keypair, socket)
})
.take(128)
.collect();
let mut seen_nodes = HashSet::<(Pubkey, SocketAddr)>::new();
// No pong has been recorded yet, so `check` is false for every node. A ping
// is generated only the first time a node is seen; repeats at the same
// instant are rate limited.
let pings: Vec<Option<Ping<Token>>> = remote_nodes
.iter()
.map(|(keypair, socket)| {
let node = (keypair.pubkey(), *socket);
let pingf = || Ping::<Token>::new_rand(&mut rng, &this_node).ok();
let (check, ping) = cache.check(now, node, pingf);
assert!(!check);
assert_eq!(seen_nodes.insert(node), ping.is_some());
ping
})
.collect();
let now = now + Duration::from_millis(1);
// Closure for phases in which the cache must not request a new ping.
let panic_ping = || -> Option<Ping<Token>> { panic!("this should not happen!") };
// Answer every generated ping with a pong. Entries without a ping are repeats
// of a node whose pong was already added earlier in this loop, so they check
// as verified and must not trigger a new ping.
for ((keypair, socket), ping) in remote_nodes.iter().zip(&pings) {
match ping {
None => {
let node = (keypair.pubkey(), *socket);
let (check, ping) = cache.check(now, node, panic_ping);
assert!(check);
assert!(ping.is_none());
}
Some(ping) => {
let pong = Pong::new(ping, keypair).unwrap();
assert!(cache.add(&pong, *socket, now));
}
}
}
let now = now + Duration::from_millis(1);
// Pongs are recent: every node is verified and no ping is requested.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let (check, ping) = cache.check(now, node, panic_ping);
assert!(check);
assert!(ping.is_none());
}
let now = now + ttl / 8;
seen_nodes.clear();
// Pong age now exceeds ttl / 8: nodes are still verified, and a new ping is
// generated once per distinct node.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let pingf = || Ping::<Token>::new_rand(&mut rng, &this_node).ok();
let (check, ping) = cache.check(now, node, pingf);
assert!(check);
assert_eq!(seen_nodes.insert(node), ping.is_some());
}
let now = now + Duration::from_millis(1);
// The pings just generated are within the ttl / 64 rate limit, so no new ping
// is created even though the pongs are older than ttl / 8.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let (check, ping) = cache.check(now, node, panic_ping);
assert!(check);
assert!(ping.is_none());
}
let now = now + ttl;
seen_nodes.clear();
// Pongs have outlived the ttl. The first check of each node still returns
// true, evicts the pong and generates a ping; repeated checks find no pong
// and are rate limited.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let pingf = || Ping::<Token>::new_rand(&mut rng, &this_node).ok();
let (check, ping) = cache.check(now, node, pingf);
if seen_nodes.insert(node) {
assert!(check);
assert!(ping.is_some());
} else {
assert!(!check);
assert!(ping.is_none());
}
}
let now = now + Duration::from_millis(1);
// All pongs are gone and pings are still rate limited.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let (check, ping) = cache.check(now, node, panic_ping);
assert!(!check);
assert!(ping.is_none());
}
let now = now + ttl / 64;
seen_nodes.clear();
// The rate-limit window has passed: a ping is generated once per distinct
// node, but nodes remain unverified because no pong has been added.
for (keypair, socket) in &remote_nodes {
let node = (keypair.pubkey(), *socket);
let pingf = || Ping::<Token>::new_rand(&mut rng, &this_node).ok();
let (check, ping) = cache.check(now, node, pingf);
assert!(!check);
assert_eq!(seen_nodes.insert(node), ping.is_some());
}
}
}