use std::time::Duration;
use ek_base::config::get_ek_settings;
use crate::{
controller::{
dispatcher::{DISPATCHER, Dispatcher},
elastic::{frequency::get_freq_tracker, progressive, recovery::recover_unique_experts},
poller::request_immediate_poll,
registry::get_registry,
routing_broadcaster::get_broadcaster,
},
proto::ek::{
object::v1::ExpertSlice,
worker::v1::{self, ExchangeResp},
},
state::{
io::{StateReader, StateReaderImpl},
models::NewNode,
writer::StateWriterImpl,
},
};
use tokio::{sync::mpsc, time::timeout};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Response, Result, Status, Streaming};
use crate::proto::ek::worker::v1::state_service_server::StateService;
/// gRPC handler type for the worker `StateService`.
///
/// The struct carries no fields; the handlers in this file reach shared
/// controller state through accessor functions such as `get_registry()` and
/// `get_broadcaster()` and through the `DISPATCHER` static.
pub struct StateServerImpl {}
impl StateServerImpl {
async fn process_heartbeat(
msg: &v1::ExchangeReq,
hostname: &str,
w: &StateWriterImpl,
graceful_shutdown_handled: &mut bool,
progressive_triggered: &mut bool,
stream_tx: &mpsc::Sender<Result<ExchangeResp, Status>>,
) {
let reader = StateReaderImpl::new();
let existing = reader.node_by_hostname(&msg.id).await;
let is_new = match &existing {
Ok(None) => true,
Ok(Some(_)) => {
let active = reader.active_nodes().await.unwrap_or_default();
!active.iter().any(|n| n.hostname == msg.id)
}
Err(_) => false,
};
let mut config = match existing {
Ok(Some(existing_node)) => {
log::debug!("Preserving existing config for worker {}", msg.id);
existing_node.config
}
Ok(None) => {
log::debug!("Creating new config for worker {}", msg.id);
serde_json::json!({})
}
Err(_) => {
log::warn!(
"Failed to read existing config for worker {}, creating new",
msg.id
);
serde_json::json!({})
}
};
config["addr"] = serde_json::json!(msg.addr.clone());
config["channel"] = serde_json::json!(msg.channel.clone());
if msg.rdma_tcp_port > 0 {
config["rdma_tcp_port"] = serde_json::json!(msg.rdma_tcp_port);
}
if !msg.wm_addr.is_empty() {
config["wm_addr"] = serde_json::json!(msg.wm_addr.clone());
}
if msg.mem_capacity_mb > 0 {
config["mem_capacity_mb"] = serde_json::json!(msg.mem_capacity_mb);
}
if is_new && !*progressive_triggered {
get_registry().lock().await.reregister(&msg.id);
if let Ok(Some(node)) = reader.node_by_hostname(&msg.id).await {
if let Ok(old) = reader.experts_by_node(node.id).await {
if !old.is_empty() {
log::info!(
"Clearing {} stale expert rows for returning worker {}",
old.len(),
msg.id
);
let _ = w.delete_experts_by_node(node.id).await;
}
}
}
}
let err = w
.node_upsert(NewNode {
hostname: msg.id.clone(),
device: msg.device.clone(),
config,
})
.await;
if let Err(e) = err {
log::error!("worker ping error, can not upsert node: {e}");
}
let e = w.node_update_seen(&msg.id).await;
if let Err(e) = e {
log::error!("worker ping error: {e}");
}
if !msg.loaded_experts.is_empty() {
log::debug!(
"Worker {} reports {} loaded experts",
msg.id,
msg.loaded_experts.len()
);
}
if let Err(e) = w.update_expert_load_states(&msg.id, &msg.loaded_experts).await {
log::error!("Failed to update expert load states for {}: {}", msg.id, e);
}
if !msg.expert_request_counts.is_empty() {
get_freq_tracker().add_counts(msg.expert_request_counts.clone());
}
if is_new && !*progressive_triggered {
*progressive_triggered = true;
let hn = msg.id.clone();
tokio::spawn(async move {
progressive::progressive_assign(&hn).await;
});
}
if msg.last_will && !*graceful_shutdown_handled {
*graceful_shutdown_handled = true;
log::info!(
"Worker {} signaled graceful shutdown, starting proactive migration \
(keeping in routing until stream closes)",
hostname
);
let hn = hostname.to_owned();
let tx = stream_tx.clone();
tokio::spawn(async move {
recover_unique_experts(&hn).await;
let resp = ExchangeResp {
state: None,
preemption_complete: true,
};
if let Err(e) = tx.send(Ok(resp)).await {
log::warn!("Failed to send preemption_complete to {hn}: {e}");
}
});
}
}
async fn listen_worker_ping(
mut req: tonic::Request<Streaming<v1::ExchangeReq>>,
hostname: String,
first_message: Option<v1::ExchangeReq>,
stream_tx: mpsc::Sender<Result<ExchangeResp, Status>>,
) {
let w = StateWriterImpl {};
let mut graceful_shutdown_handled = false;
let mut progressive_triggered = false;
if let Some(msg) = &first_message {
Self::process_heartbeat(
msg, &hostname, &w,
&mut graceful_shutdown_handled, &mut progressive_triggered,
&stream_tx,
).await;
}
loop {
match timeout(Duration::from_secs(get_ek_settings().controller.fault_detection.heartbeat_timeout_secs), req.get_mut().message()).await {
Ok(Ok(Some(msg))) => {
Self::process_heartbeat(
&msg, &hostname, &w,
&mut graceful_shutdown_handled, &mut progressive_triggered,
&stream_tx,
).await;
continue;
}
Ok(Ok(None)) => {
log::warn!("worker ping stream closed for worker_id={hostname}");
let _ = w.deactivate_node(&hostname).await;
get_registry().lock().await.deregister(&hostname).await;
if graceful_shutdown_handled {
log::info!("Graceful shutdown complete for {hostname}, letting poller update routing");
let _ = w.delete_experts_by_node_hostname(&hostname).await;
} else {
get_broadcaster().remove_node(&hostname).await;
let hn = hostname.clone();
tokio::spawn(async move {
recover_unique_experts(&hn).await;
let w = StateWriterImpl::new();
let _ = w.delete_experts_by_node_hostname(&hn).await;
});
}
request_immediate_poll();
return;
}
Ok(Err(e)) => {
log::error!("worker ping stream error for worker_id={hostname}, {e}");
let _ = w.deactivate_node(&hostname).await;
get_registry().lock().await.deregister(&hostname).await;
if graceful_shutdown_handled {
log::info!("Graceful shutdown complete for {hostname}, letting poller update routing");
let _ = w.delete_experts_by_node_hostname(&hostname).await;
} else {
get_broadcaster().remove_node(&hostname).await;
let hn = hostname.clone();
tokio::spawn(async move {
recover_unique_experts(&hn).await;
let w = StateWriterImpl::new();
let _ = w.delete_experts_by_node_hostname(&hn).await;
});
}
request_immediate_poll();
return;
}
Err(e) => {
log::error!("worker ping stream timeout for worker_id={hostname}, {e}");
let _ = w.deactivate_node(&hostname).await;
get_registry().lock().await.deregister(&hostname).await;
if graceful_shutdown_handled {
log::info!("Graceful shutdown complete for {hostname}, letting poller update routing");
let _ = w.delete_experts_by_node_hostname(&hostname).await;
} else {
get_broadcaster().remove_node(&hostname).await;
let hn = hostname.clone();
tokio::spawn(async move {
recover_unique_experts(&hn).await;
let w = StateWriterImpl::new();
let _ = w.delete_experts_by_node_hostname(&hn).await;
});
}
request_immediate_poll();
return;
}
}
}
}
}
#[tonic::async_trait]
impl StateService for StateServerImpl {
type ExchangeStream = ReceiverStream<Result<ExchangeResp, Status>>;
async fn exchange(
&self,
mut request: tonic::Request<Streaming<v1::ExchangeReq>>,
) -> Result<Response<Self::ExchangeStream>> {
let mut dispather_guard = DISPATCHER.lock().await;
let (stream_tx, stream_rx) = mpsc::channel(4);
let first_message = request
.get_mut()
.message()
.await?
.ok_or(Status::invalid_argument("no message"))?;
let worker_id = first_message.id.clone();
if first_message.channel == "rdma" && first_message.rdma_tcp_port > 0 {
log::info!(
"Worker {} using RDMA with TCP port {}",
worker_id,
first_message.rdma_tcp_port
);
}
let worker_id_for_ping = worker_id.clone();
let first_msg_clone = first_message.clone();
let stream_tx_for_ping = stream_tx.clone();
tokio::spawn(async move {
StateServerImpl::listen_worker_ping(request, worker_id_for_ping.clone(), Some(first_msg_clone), stream_tx_for_ping).await;
});
let mut rx = dispather_guard.subscribe(&first_message.id).await;
tokio::spawn(async move {
while let Some(t) = rx.recv().await {
let resp = ExchangeResp {
state: Some(v1::exchange_resp::ExpertWithState {
target: Some(ExpertSlice::from(t)),
}),
preemption_complete: false,
};
if let Err(e) = stream_tx.send(Ok(resp)).await {
log::error!("stream error: {e}")
};
}
});
Ok(Response::new(Self::ExchangeStream::new(stream_rx)))
}
}
impl Default for StateServerImpl {
fn default() -> Self {
Self::new()
}
}
impl StateServerImpl {
pub fn new() -> Self {
Self {}
}
}