mod core;
use std::sync::{
atomic::{AtomicBool, Ordering},
{Arc, LazyLock, Mutex, OnceLock},
};
use std::time;
use std::time::Duration;
use std::{env, panic};
use ek_base::tracing::grpc::OTelGrpcServerMiddleware;
use state::StateInspector;
use tokio::select;
use tokio::signal;
use tokio_util::sync::CancellationToken;
mod manager;
pub mod server;
pub mod state;
pub mod x;
use crate::controller::registry::{ShmqWorkerReq, ShmqWorkerResp};
use crate::metrics::spawn_metrics_server;
use crate::proto::ek::worker::v1::computation_service_server::ComputationServiceServer;
use crate::shmq::{RdmaEndpointServer, ShmQueue, rdma_impl::RdmaQueue};
use crate::worker::core::EKInstanceGateSync;
use crate::worker::server::BasicExpertImpl;
use crate::x::get_graceful_shutdown_ch;
use super::worker::state::StateClient;
use ek_base::{config::get_ek_settings, error::EKResult};
/// Process-wide RDMA request queue. Installed once by
/// `create_rdma_queues_with_tcp_server`; worker threads receive requests from it.
static RDMA_REQ_QUEUE: OnceLock<Arc<Mutex<RdmaQueue<ShmqWorkerReq>>>> = OnceLock::new();
/// Process-wide RDMA response queue. Installed once by
/// `create_rdma_queues_with_tcp_server`; worker threads send responses through it.
static RDMA_RESP_QUEUE: OnceLock<Arc<Mutex<RdmaQueue<ShmqWorkerResp>>>> = OnceLock::new();
/// Connection flag for the RDMA queues. Starts as `false`, is written through
/// `update_rdma_connection_status` and read through `is_rdma_queue_connected`.
static RDMA_CONNECTION_STATUS: AtomicBool = AtomicBool::new(false);
/// Port reported by the RDMA TCP endpoint server, recorded once at setup time.
static RDMA_TCP_PORT: OnceLock<u16> = OnceLock::new();
/// Parallelism level read lazily from the `EK_WORKER_PARALLEL` environment
/// variable. Falls back to `1` when the variable is unset, not valid Unicode,
/// or does not parse as a `usize`.
static WORKER_PARALLEL: LazyLock<usize> = LazyLock::new(|| {
    env::var("EK_WORKER_PARALLEL")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(1)
});
/// Returns the shared RDMA request queue, or `None` if it has not been
/// installed yet.
pub fn get_rdma_req_queue() -> Option<&'static Arc<Mutex<RdmaQueue<ShmqWorkerReq>>>> {
    let installed = RDMA_REQ_QUEUE.get()?;
    Some(installed)
}
/// Returns the shared RDMA response queue, or `None` if it has not been
/// installed yet.
pub fn get_rdma_resp_queue() -> Option<&'static Arc<Mutex<RdmaQueue<ShmqWorkerResp>>>> {
    let installed = RDMA_RESP_QUEUE.get()?;
    Some(installed)
}
/// Reports the current value of the RDMA connection flag.
///
/// The flag is read with relaxed ordering; it is `false` until
/// `update_rdma_connection_status(true)` has been called.
pub fn is_rdma_queue_connected() -> bool {
    let flag = &RDMA_CONNECTION_STATUS;
    flag.load(Ordering::Relaxed)
}
/// Overwrites the RDMA connection flag with `connected` (relaxed ordering).
pub fn update_rdma_connection_status(connected: bool) {
    let flag = &RDMA_CONNECTION_STATUS;
    flag.store(connected, Ordering::Relaxed);
}
/// Disconnects both RDMA queues (when they have been installed) and clears
/// the connection flag.
///
/// This is a cleanup routine, so it must make progress even if a worker
/// thread panicked while holding one of the queue locks. A poisoned mutex is
/// therefore recovered instead of unwrapped: panicking here would skip the
/// remaining queue and leave the connection flag set to `true`.
pub fn close_rdma_queues() {
    if let Some(req_queue) = RDMA_REQ_QUEUE.get() {
        // Recover the guard from a poisoned lock so the queue is still closed.
        let mut rq = req_queue
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        rq.disconnect();
    }
    if let Some(resp_queue) = RDMA_RESP_QUEUE.get() {
        let mut rq = resp_queue
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        rq.disconnect();
    }
    update_rdma_connection_status(false);
}
/// Returns the recorded RDMA TCP endpoint port, or `None` if no port has been
/// recorded yet.
pub fn get_rdma_tcp_port() -> Option<u16> {
    let port = RDMA_TCP_PORT.get()?;
    Some(*port)
}
async fn create_rdma_queues_with_tcp_server(
poison: Arc<Mutex<bool>>,
) -> EKResult<(u16, std::thread::JoinHandle<()>)> {
let req_queue = RdmaQueue::<ShmqWorkerReq>::new(None, 256, false)?;
let resp_queue = RdmaQueue::<ShmqWorkerResp>::new(None, 256, true)?;
let req_queue_arc = Arc::new(Mutex::new(req_queue));
let resp_queue_arc = Arc::new(Mutex::new(resp_queue));
RDMA_REQ_QUEUE.set(req_queue_arc.clone()).map_err(|_| {
ek_base::error::EKError::InvalidInput("Failed to set RDMA request queue".into())
})?;
RDMA_RESP_QUEUE.set(resp_queue_arc.clone()).map_err(|_| {
ek_base::error::EKError::InvalidInput("Failed to set RDMA response queue".into())
})?;
let endpoint_server = RdmaEndpointServer::new(req_queue_arc, resp_queue_arc, poison)
.map_err(|e| ek_base::error::EKError::IoError(e))?;
let tcp_port = endpoint_server.port();
RDMA_TCP_PORT
.set(tcp_port)
.map_err(|_| ek_base::error::EKError::InvalidInput("Failed to set RDMA TCP port".into()))?;
let endpoint_server_handle = std::thread::spawn(move || match endpoint_server.start() {
Ok(()) => {
log::info!("RDMA TCP endpoint server completed successfully");
}
Err(e) => {
log::error!("RDMA TCP endpoint server failed: {}", e);
}
});
Ok((tcp_port, endpoint_server_handle))
}
pub async fn worker_main() -> EKResult<()> {
let settings = get_ek_settings();
spawn_metrics_server(&settings.worker.metrics);
let token = CancellationToken::new();
let cli_cancel = token.clone();
let poison = Arc::new(Mutex::new(false));
let state_inspect = StateInspector::spawn();
let async_srv;
let mut sync_srvs = Vec::new();
tch::set_num_threads(*WORKER_PARALLEL as _);
let rdma_tcp_port: Option<u16> = if settings.worker.channel == "rdma" {
match create_rdma_queues_with_tcp_server(poison.clone()).await {
Ok((tcp_port, handle)) => {
log::info!("RDMA queues and TCP server created successfully");
sync_srvs.push(handle);
Some(tcp_port)
}
Err(e) => {
log::error!("Failed to create RDMA queues: {e}");
return Err(e);
}
}
} else {
None
};
let cli = tokio::task::spawn(async move {
let worker_id = x::get_worker_id();
log::info!("ek hostname: {worker_id:}");
let control_endpoint = x::get_controller_addr();
log::info!("control endpoint {:}", control_endpoint.uri());
let mut state_client =
StateClient::new_with_rdma_tcp_port(control_endpoint, &worker_id, rdma_tcp_port);
if let Err(e) = state_client.run(cli_cancel).await {
log::error!("state client error {e:}");
}
});
match settings.worker.channel.as_str() {
"grpc" => {
let srv = tokio::task::spawn(async move {
let server = BasicExpertImpl::new();
let settings = &get_ek_settings().worker;
let addr = format!("{}:{}", settings.listen, settings.ports.main)
.parse()
.unwrap();
log::info!("worker server listening on {addr}");
let layer = tower::ServiceBuilder::new()
.layer_fn(OTelGrpcServerMiddleware::new)
.into_inner();
let err = tonic::transport::Server::builder()
.layer(layer)
.add_service(
ComputationServiceServer::new(server)
.max_decoding_message_size(200 * 1024 * 1024)
.max_encoding_message_size(200 * 1024 * 1024),
)
.serve(addr)
.await;
if let Err(e) = err {
log::error!("server error {e:?}");
}
});
async_srv = srv;
}
"shm" => {
let node_name = x::get_worker_id();
let recv_channel = loop {
if let Some(channel) =
ShmQueue::<ShmqWorkerReq>::open(&format!("ek-shmq-req-{}", node_name))
{
break Arc::new(Mutex::new(channel));
}
};
let send_channel = loop {
if let Some(channel) =
ShmQueue::<ShmqWorkerResp>::open(&format!("ek-shmq-resp-{}", node_name))
{
break Arc::new(Mutex::new(channel));
}
};
let thread_count: usize = env::var("EK_WORKER_THREADS")
.map(|v| v.parse().unwrap_or(1))
.unwrap_or(1);
for _ in 0..thread_count {
let recv_channel = recv_channel.clone();
let send_channel = send_channel.clone();
let gate = EKInstanceGateSync::default();
let poison = poison.clone();
let srv = std::thread::spawn(move || {
'main: loop {
let req = loop {
if *poison.lock().unwrap() {
break 'main;
}
if let Ok(req) = recv_channel.lock().unwrap().recv() {
break req;
}
std::thread::sleep(Duration::from_micros(100));
};
log::debug!(
"received request: id={} expert={}",
req.id(),
req.expert_id()
);
let now = time::Instant::now();
let expert_id = req.expert_id();
let input_tensor = req.input_tensor();
let output_tensor = loop {
match gate.forward_sync_core(&expert_id, input_tensor) {
Ok(result) => {
log::debug!(
"forward_sync_core completed for expert={}",
expert_id
);
break result;
}
Err(err) => log::warn!("forward_sync_core {err}, retrying..."),
}
std::thread::sleep(Duration::from_secs(1));
};
let resp = ShmqWorkerResp::new(req.id(), output_tensor);
while send_channel.lock().unwrap().send(&resp).is_err() {
log::warn!("send_channel full, retrying...");
std::thread::sleep(Duration::from_micros(100));
}
log::info!(
"request id={} expert={} processed in {}us",
req.id(),
req.expert_id(),
now.elapsed().as_micros(),
);
}
});
sync_srvs.push(srv);
}
let token = token.clone();
async_srv = tokio::spawn(async move {
select! {
_ = token.cancelled() => {
log::info!("async service cancelled");
}
}
});
}
"rdma" => {
let recv_channel = get_rdma_req_queue()
.ok_or_else(|| {
ek_base::error::EKError::NotFound("RDMA request queue not found".into())
})?
.clone();
let send_channel = get_rdma_resp_queue()
.ok_or_else(|| {
ek_base::error::EKError::NotFound("RDMA response queue not found".into())
})?
.clone();
let thread_count: usize = env::var("EK_WORKER_THREADS")
.map(|v| v.parse().unwrap_or(1))
.unwrap_or(1);
for idx in 0..thread_count {
let recv_channel = recv_channel.clone();
let send_channel = send_channel.clone();
let gate = EKInstanceGateSync::default();
let poison = poison.clone();
let srv = std::thread::spawn(move || {
'main: loop {
let req = loop {
if *poison.lock().unwrap() {
break 'main;
}
if !is_rdma_queue_connected() {
std::thread::sleep(Duration::from_secs(2));
}
match recv_channel.lock().unwrap().recv() {
Ok(req) => break req,
Err(_) => {
std::thread::sleep(Duration::from_micros(100));
continue;
}
}
};
log::debug!(
"thread {} received RDMA request: id={} expert={}",
idx,
req.id(),
req.expert_id()
);
let now = time::Instant::now();
let expert_id = req.expert_id();
let input_tensor = req.input_tensor();
let output_tensor = loop {
match gate.forward_sync_core(&expert_id, input_tensor) {
Ok(result) => {
log::debug!(
"forward_sync_core completed for expert={}",
expert_id
);
break result;
}
Err(err) => log::warn!("forward_sync_core {err}, retrying..."),
}
std::thread::sleep(Duration::from_secs(1));
};
let resp = ShmqWorkerResp::new(req.id(), output_tensor);
let send_start = time::Instant::now();
while send_channel.lock().unwrap().send(&resp).is_err() {
log::warn!("RDMA send_channel full, retrying...");
std::thread::sleep(Duration::from_micros(100));
}
log::info!(
"thread {} RDMA request id={} expert={} processed in {}us, send wait {}us",
idx,
req.id(),
req.expert_id(),
now.elapsed().as_micros(),
send_start.elapsed().as_micros(),
);
}
});
sync_srvs.push(srv);
}
let token = token.clone();
async_srv = tokio::spawn(async move {
select! {
_ = token.cancelled() => {
log::info!("async service cancelled");
}
}
});
}
_ => {
panic!("Unsupported worker channel: {}", settings.worker.channel);
}
}
select! {
_ = cli => { },
_ = async_srv => { },
_ = state_inspect => { },
_ = signal::ctrl_c() => {
log::info!("ctrl-c signal received, shutting down");
*poison.lock().unwrap() = true;
token.clone().cancel();
let(_,rx) = get_graceful_shutdown_ch();
rx.lock().await.recv().await;
for srv in sync_srvs {
srv.join().unwrap();
}
log::info!("graceful shutdown channel received, shutting down now");
}
};
Ok(())
}