use std::{
collections::{BTreeMap, HashMap},
hash::{DefaultHasher, Hash, Hasher},
sync::{Arc, LazyLock},
};
use tokio::sync::{
Mutex,
mpsc::{self, Receiver, Sender},
};
use tonic::async_trait;
use crate::{
controller::elastic::frequency::get_freq_tracker,
state::models::{Expert, NodeWithExperts},
};
/// Interface for pushing per-node expert lists to subscribers keyed by hostname.
///
/// The contract notes below describe what the `DispatcherImpl` implementation
/// in this file does; other implementors are not visible from here.
#[async_trait]
pub trait Dispatcher {
/// Consumes a snapshot of nodes together with their experts and forwards the
/// resulting expert list to each node's subscriber.
async fn update(&mut self, state: Vec<NodeWithExperts>);
/// Registers a subscription for `hostname` and returns the receiving half of
/// the channel on which expert lists for that host are delivered.
async fn subscribe(&mut self, hostname: &str) -> Receiver<Vec<Expert>>;
/// Drops the subscription previously registered for `hostname`, if any.
async fn unsubscribe(&mut self, hostname: &str);
}
/// Channel-backed [`Dispatcher`] that remembers what was last pushed per host.
pub struct DispatcherImpl {
// Sending half of each subscriber's channel, keyed by hostname.
ch_store: BTreeMap<String, Sender<Vec<Expert>>>,
// Fingerprint of the expert set most recently processed for each hostname;
// `update` skips a host whose freshly computed fingerprint matches this value.
last_fingerprint: HashMap<String, u64>,
}
/// Process-wide dispatcher instance.
///
/// Built lazily on first access and shared behind an `Arc` wrapping an async
/// `Mutex`, so callers lock it before invoking any [`Dispatcher`] method.
pub static DISPATCHER: LazyLock<Arc<Mutex<DispatcherImpl>>> =
    LazyLock::new(|| Arc::new(Mutex::new(DispatcherImpl::new())));
/// Computes an order-independent hash of the expert ids in `experts`.
///
/// Only `expert_id` contributes to the result; the ids are sorted before being
/// fed to the hasher, so two slices holding the same ids in a different order
/// produce the same value.
fn fingerprint(experts: &[Expert]) -> u64 {
    let mut sorted_ids: Vec<&str> = Vec::with_capacity(experts.len());
    sorted_ids.extend(experts.iter().map(|expert| expert.expert_id.as_str()));
    sorted_ids.sort_unstable();

    let mut hasher = DefaultHasher::new();
    sorted_ids
        .into_iter()
        .for_each(|expert_id| expert_id.hash(&mut hasher));
    hasher.finish()
}
impl DispatcherImpl {
    /// Creates a dispatcher with no subscribers and no recorded fingerprints.
    fn new() -> Self {
        let ch_store = BTreeMap::new();
        let last_fingerprint = HashMap::new();
        Self {
            ch_store,
            last_fingerprint,
        }
    }
}
#[async_trait]
impl Dispatcher for DispatcherImpl {
    /// Recomputes the ordered expert list of every node in `state` and pushes
    /// it to that node's subscriber when the set of expert ids has changed.
    ///
    /// For each node:
    /// 1. experts whose `state["status"]` is the string `"scheduled"` are dropped;
    /// 2. the remaining experts are stably sorted. When the frequency tracker
    ///    has data, the sort is by descending tracker rate, and experts with a
    ///    rate of zero are ordered among themselves by descending replica
    ///    count. Without tracker data, the sort is by descending replica count
    ///    alone;
    /// 3. the list is fingerprinted (ids only, order-independent) and skipped
    ///    if the fingerprint equals the one recorded for the hostname;
    /// 4. otherwise the fingerprint is recorded and the list is sent on the
    ///    hostname's channel, if one is registered. A failed send is logged.
    ///
    /// Replica counts are taken over the whole snapshot and include experts
    /// that are later filtered out as `"scheduled"`.
    ///
    /// NOTE(review): `send` on the bounded channel is awaited here, so a full
    /// channel suspends `update` (and whatever lock the caller holds) until the
    /// subscriber drains it. Confirm this back-pressure is intended.
    async fn update(&mut self, state: Vec<NodeWithExperts>) {
        // Number of times each expert id appears across all nodes in the snapshot.
        let mut replica_count: HashMap<&str, usize> = HashMap::new();
        for nwe in &state {
            for expert in &nwe.experts {
                *replica_count.entry(expert.expert_id.as_str()).or_default() += 1;
            }
        }

        for data in &state {
            let node = &data.node;
            let mut experts: Vec<Expert> = data
                .experts
                .iter()
                .filter(|e| e.state.get("status").and_then(|s| s.as_str()) != Some("scheduled"))
                .cloned()
                .collect();

            let freq = get_freq_tracker();
            if freq.has_data() {
                // Query the tracker exactly once per expert. A comparator that
                // re-reads the rate on every comparison can observe different
                // values for the same expert during one sort, which breaks the
                // total-order requirement of the slice sort routines.
                experts.sort_by_cached_key(|e| {
                    let rate = freq.rate_in_window(&e.expert_id);
                    // Replica count only breaks ties between zero-rate experts;
                    // experts with equal non-zero rates keep their input order.
                    let tie_break = if rate == 0 {
                        replica_count
                            .get(e.expert_id.as_str())
                            .copied()
                            .unwrap_or(0)
                    } else {
                        0
                    };
                    (std::cmp::Reverse(rate), std::cmp::Reverse(tie_break))
                });
            } else {
                experts.sort_by_cached_key(|e| {
                    std::cmp::Reverse(
                        replica_count
                            .get(e.expert_id.as_str())
                            .copied()
                            .unwrap_or(0),
                    )
                });
            }

            // Skip hosts whose expert-id set is unchanged since the last pass.
            let fp = fingerprint(&experts);
            if self.last_fingerprint.get(&node.hostname) == Some(&fp) {
                continue;
            }
            self.last_fingerprint.insert(node.hostname.clone(), fp);

            if let Some(ch) = self.ch_store.get(&node.hostname)
                && let Err(e) = ch.send(experts).await
            {
                log::error!(
                    "Failed to send expert update to channel for hostname: {} err: {}",
                    node.hostname,
                    e
                );
            }
        }
    }

    /// Registers a fresh channel (capacity 10) for `hostname`, replacing any
    /// sender already stored for it, and returns the receiving half.
    ///
    /// The recorded fingerprint for `hostname` is cleared so that the next
    /// `update` containing this host is not suppressed as unchanged.
    async fn subscribe(&mut self, hostname: &str) -> Receiver<Vec<Expert>> {
        let (tx, rx) = mpsc::channel(10);
        self.ch_store.insert(hostname.to_string(), tx);
        self.last_fingerprint.remove(hostname);
        rx
    }

    /// Removes the sender stored for `hostname`; a no-op when none exists.
    async fn unsubscribe(&mut self, hostname: &str) {
        self.ch_store.remove(hostname);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::models::Expert;

    /// Builds an `Expert` whose only meaningful field is `expert_id`; the
    /// fingerprint ignores every other field.
    fn make_expert(id: &str) -> Expert {
        Expert {
            id: 0,
            instance_id: 0,
            node_id: 0,
            expert_id: id.to_string(),
            replica: 0,
            state: serde_json::Value::Null,
        }
    }

    #[test]
    fn fingerprint_order_independent() {
        let a = vec![make_expert("m/l0-e1"), make_expert("m/l0-e0")];
        let b = vec![make_expert("m/l0-e0"), make_expert("m/l0-e1")];
        assert_eq!(fingerprint(&a), fingerprint(&b));
    }

    #[test]
    fn fingerprint_differs_on_change() {
        let a = vec![make_expert("m/l0-e0"), make_expert("m/l0-e1")];
        let b = vec![make_expert("m/l0-e0"), make_expert("m/l0-e2")];
        assert_ne!(fingerprint(&a), fingerprint(&b));
    }

    #[test]
    fn fingerprint_empty() {
        let empty = fingerprint(&[]);
        // Deterministic for the empty input...
        assert_eq!(empty, fingerprint(&[]));
        // ...and distinguishable from a non-empty expert set, so a host going
        // from "no experts" to "some experts" is seen as a change.
        assert_ne!(empty, fingerprint(&[make_expert("m/l0-e0")]));
    }
}
impl DispatcherImpl {
    /// Sends `experts` straight to the subscriber registered for `hostname`.
    ///
    /// Unlike `update`, this neither consults nor records a fingerprint. A
    /// missing channel or a failed send is logged at warn level and otherwise
    /// ignored.
    pub async fn trigger_worker(&self, hostname: &str, experts: Vec<Expert>) {
        let Some(sender) = self.ch_store.get(hostname) else {
            log::warn!("trigger_worker: no active channel for {hostname}");
            return;
        };

        if let Err(e) = sender.send(experts).await {
            log::warn!("trigger_worker: failed to send to {hostname}: {e}");
        }
    }
}