use std::{sync::{Arc, LazyLock}, time::Duration};
use diesel::{
BelongingToDsl, ExpressionMethods, GroupedBy, SelectableHelper,
query_dsl::methods::{FilterDsl, SelectDsl},
};
use diesel_async::RunQueryDsl;
use ek_base::{config::get_ek_settings, error::EKResult};
use std::time::SystemTime;
use tokio::{sync::Notify, time::{self}};
use tonic::async_trait;
use crate::{
schema,
state::{
models::{self, NodeWithExperts},
pool,
},
};
use super::{
dispatcher::{DISPATCHER, Dispatcher},
elastic::{ELASTIC_MANAGER, frequency::get_freq_tracker},
routing_broadcaster::RoutingBroadcaster,
};
/// Process-wide wake-up signal for the state poller.
///
/// `request_immediate_poll` signals it and the poller loop in
/// `StatePollerImpl::run` awaits it alongside its periodic interval.
static POLL_NOW: LazyLock<Notify> = LazyLock::new(Notify::new);
/// Asks the state poller to run a polling pass now instead of waiting for its
/// next scheduled tick.
///
/// This only signals `POLL_NOW`; the poll itself runs on the poller's task.
/// NOTE(review): per the tokio `Notify` docs, `notify_one` stores at most one
/// permit when nobody is waiting, so repeated calls while a poll is in flight
/// should coalesce into a single extra poll; confirm this is the intent.
pub fn request_immediate_poll() {
POLL_NOW.notify_one();
}
/// A long-running task that polls cluster state.
#[async_trait]
pub trait StatePoller {
/// Drives the poller. The implementation in this file loops forever and
/// logs per-tick errors instead of returning them.
async fn run(&mut self) -> EKResult<()>;
}
/// State poller that reads node/expert rows from the database and pushes the
/// result to the dispatcher and to the routing broadcaster.
pub struct StatePollerImpl {
// Receives the expert -> endpoints map built on every polling pass.
broadcaster: Arc<RoutingBroadcaster>,
}
#[async_trait]
impl StatePoller for StatePollerImpl {
    /// Runs the polling loop.
    ///
    /// A polling pass is triggered either by the periodic interval
    /// (`controller.fault_detection.poller_interval_secs`) or by
    /// [`request_immediate_poll`]. An immediate poll also restarts the
    /// interval so the next periodic tick is a full period away.
    ///
    /// Errors from a single pass are logged and the loop keeps going; the loop
    /// has no exit path, so this future only ends when its task is cancelled.
    async fn run(&mut self) -> EKResult<()> {
        let configured_secs = get_ek_settings()
            .controller
            .fault_detection
            .poller_interval_secs;
        // `tokio::time::interval` panics when given a zero period, which would
        // kill the poller task on a misconfigured interval. Clamp instead.
        let poller_secs = configured_secs.max(1);
        if poller_secs != configured_secs {
            log::warn!(
                "state poller: poller_interval_secs={} is invalid, using {}s instead",
                configured_secs,
                poller_secs
            );
        }
        log::info!("state poller started (interval={}s)", poller_secs);

        let mut interval = time::interval(Duration::from_secs(poller_secs));
        // If one pass takes longer than the period, do not fire the missed
        // ticks back to back (the default `Burst` behaviour); just schedule
        // the next tick one period after the late one.
        interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);

        loop {
            tokio::select! {
                _ = interval.tick() => {},
                _ = POLL_NOW.notified() => {
                    log::info!("state poller: immediate tick requested (node removal)");
                    interval.reset();
                },
            }
            log::info!("state poller tick");
            // A failed pass must not stop the poller: log and wait for the next tick.
            if let Err(e) = self.poll_state().await {
                log::error!("state poller error: {e}");
            }
        }
    }
}
impl StatePollerImpl {
    /// Creates a poller that publishes routing updates through `broadcaster`.
    pub fn new(broadcaster: Arc<RoutingBroadcaster>) -> Self {
        Self { broadcaster }
    }

    /// Performs one polling pass.
    ///
    /// Steps, in order:
    /// 1. Look up the configured instance by name; if it does not exist yet the
    ///    pass is skipped and `Ok(())` is returned.
    /// 2. Load the nodes whose `last_seen_at` is newer than
    ///    `now - node_active_threshold_secs`, plus the experts that belong to
    ///    those nodes and to this instance.
    /// 3. Hand the grouped result to the dispatcher, then to the routing
    ///    broadcaster.
    /// 4. Commit the frequency tracker tick and run one elastic manager tick.
    ///
    /// # Errors
    /// Returns an error if a pooled connection cannot be obtained or a query
    /// fails (other than the instance lookup returning `NotFound`).
    async fn poll_state(&mut self) -> EKResult<()> {
        let mut conn = pool::POOL.get().await?;
        let settings = get_ek_settings();

        let instance = match schema::instance::table
            .filter(schema::instance::name.eq(settings.inference.instance_name.clone()))
            .first::<models::Instance>(&mut conn)
            .await
        {
            Ok(found) => found,
            Err(diesel::result::Error::NotFound) => {
                log::debug!(
                    "state poller: instance '{}' not found yet, skipping tick",
                    settings.inference.instance_name
                );
                return Ok(());
            }
            Err(e) => return Err(e.into()),
        };

        // Nodes count as active when seen after this instant. If the
        // subtraction is not representable, fall back to the epoch.
        let threshold_secs = settings.controller.fault_detection.node_active_threshold_secs;
        let threshold_time = SystemTime::now()
            .checked_sub(Duration::from_secs(threshold_secs))
            .unwrap_or(SystemTime::UNIX_EPOCH);

        let nodes = schema::node::table
            .filter(schema::node::last_seen_at.gt(threshold_time))
            .select(models::Node::as_select())
            .load(&mut conn)
            .await?;
        let experts = models::Expert::belonging_to(&nodes)
            .select(models::Expert::as_select())
            .filter(schema::expert::instance_id.eq(instance.id))
            .load(&mut conn)
            .await?;
        // All queries are done: give the connection back now so it is not kept
        // checked out while waiting on the dispatcher / elastic manager below.
        drop(conn);

        // `grouped_by` yields one expert list per node, in node order, so the
        // lists can be zipped back onto the nodes.
        let node_with_expert = experts
            .grouped_by(&nodes)
            .into_iter()
            .zip(nodes)
            .map(|(e, n)| NodeWithExperts {
                experts: e,
                node: n,
            })
            .collect::<Vec<NodeWithExperts>>();

        let nodes_count = node_with_expert.len();
        log::info!(nodes_count; "polling nodes");

        // The dispatcher guard is a temporary and is released at the end of
        // this statement, before the routing update starts.
        DISPATCHER
            .lock()
            .await
            .update(node_with_expert.clone())
            .await;

        self.update_routing(node_with_expert).await?;
        get_freq_tracker().commit_tick();
        ELASTIC_MANAGER.lock().await.run_tick().await;
        Ok(())
    }

    /// Returns whether `expert` should be routed to.
    ///
    /// A null `state` counts as loaded; otherwise the expert is loaded only
    /// when `state["status"]` is the string `"loaded"`.
    /// NOTE(review): the null case presumably covers rows written before the
    /// status field existed; confirm against the writer of `state`.
    fn is_expert_loaded(expert: &models::Expert) -> bool {
        if expert.state.is_null() {
            return true;
        }
        expert.state.get("status").and_then(|s| s.as_str()) == Some("loaded")
    }

    /// Builds the expert -> worker endpoints map from the polled nodes and
    /// sends it to the routing broadcaster in one batch.
    ///
    /// Each node contributes one endpoint, built from its `config` values
    /// (`addr`, `channel`, `rdma_tcp_port`, `wm_addr`), its hostname and its
    /// device. The endpoint is added under every expert on that node that
    /// [`Self::is_expert_loaded`] accepts; other experts are counted as
    /// pending and skipped. Always returns `Ok(())`.
    async fn update_routing(&self, node_with_experts: Vec<NodeWithExperts>) -> EKResult<()> {
        use crate::proto::ek::control::v1::WorkerEndpoint;
        use std::collections::HashMap;

        let mut routing_updates: HashMap<String, Vec<WorkerEndpoint>> = HashMap::new();
        let mut skipped_pending = 0;

        for nwe in node_with_experts {
            let config = &nwe.node.config;
            // String-valued config lookup with a fallback for a missing or
            // non-string entry.
            let config_str = |key: &str, default: &str| -> String {
                config
                    .get(key)
                    .and_then(|value| value.as_str())
                    .unwrap_or(default)
                    .to_string()
            };

            // Missing -> 0. A value that does not fit in u32 is also mapped to
            // 0 (with a warning) instead of being silently truncated by `as`.
            let rdma_tcp_port = match config.get("rdma_tcp_port").and_then(|p| p.as_u64()) {
                None => 0,
                Some(port) => u32::try_from(port).unwrap_or_else(|_| {
                    log::warn!(
                        "state poller: rdma_tcp_port {} does not fit in u32, treating it as unset (0)",
                        port
                    );
                    0
                }),
            };

            let endpoint = WorkerEndpoint {
                grpc_addr: config_str("addr", "unknown"),
                channel: config_str("channel", "grpc"),
                rdma_tcp_port,
                shm_queue_prefix: nwe.node.hostname.clone(),
                device: nwe.node.device.clone(),
                wm_addr: config_str("wm_addr", ""),
            };

            for expert in nwe.experts {
                if !Self::is_expert_loaded(&expert) {
                    skipped_pending += 1;
                    continue;
                }
                routing_updates
                    .entry(expert.expert_id)
                    .or_default()
                    .push(endpoint.clone());
            }
        }

        let multi_replica_count = routing_updates.values().filter(|v| v.len() > 1).count();
        if skipped_pending > 0 {
            log::info!(
                "Updating routing table with {} loaded experts ({} with multiple replicas), {} pending experts skipped",
                routing_updates.len(),
                multi_replica_count,
                skipped_pending
            );
        } else if multi_replica_count > 0 {
            log::info!(
                "Updating routing table with {} experts ({} with multiple replicas)",
                routing_updates.len(),
                multi_replica_count
            );
        } else {
            log::info!("Updating routing table with {} experts", routing_updates.len());
        }

        self.broadcaster.batch_update(routing_updates).await;
        Ok(())
    }
}
pub fn start_poll(broadcaster: Arc<RoutingBroadcaster>) {
let mut poller = StatePollerImpl::new(broadcaster);
tokio::spawn(async move {
if let Err(e) = poller.run().await {
log::error!("state poller error {e}");
}
});
}