use std::{
collections::HashMap,
sync::{
Arc, OnceLock,
atomic::{AtomicUsize, Ordering},
},
};
use crate::{
shmq::{GeneralShmQueueBytes, RdmaEndpointClient, ShmQueue, rdma_impl::RdmaQueue},
state::io::{StateReader, StateReaderImpl},
};
use ek_base::{
error::{EKError, EKResult},
tracing::grpc::OTelGrpcClientMiddleware,
};
use ndarray_rand::rand;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tonic::transport::Channel;
use tower::ServiceBuilder;
use url::Url;
/// Number of connection attempts `connect_to_worker_via_tcp_with_retry` makes before giving up.
const MAX_RETRIES: u32 = 3;
/// Back-off after the first failed RDMA connection attempt, in milliseconds; it doubles after
/// every further failure (`BASE_DELAY_MS * 2^attempt`).
const BASE_DELAY_MS: u64 = 1000;
/// Owned expert identifier, used as the key of the registry's expert -> channel map.
pub type ExpertId = String;
/// Borrowed form of [`ExpertId`].
pub type ExpertIdRef<'a> = &'a str;
/// Shared-memory (request queue, response queue) pair; the registry keeps one per worker host.
pub type LocalShmChannel = (
Arc<Mutex<ShmQueue<'static, ShmqWorkerReq>>>,
Arc<Mutex<ShmQueue<'static, ShmqWorkerResp>>>,
);
/// RDMA (request queue, response queue) pair; the registry keeps one per worker host.
pub type RdmaChannel = (
Arc<Mutex<RdmaQueue<ShmqWorkerReq>>>,
Arc<Mutex<RdmaQueue<ShmqWorkerResp>>>,
);
/// A handle to one worker that serves an expert, over the transport advertised in that
/// worker node's `channel` config entry.
#[derive(Clone)]
pub enum ExpertClient {
/// A tonic gRPC channel wrapped in `OTelGrpcClientMiddleware`.
Grpc(OTelGrpcClientMiddleware),
/// Shared-memory request/response queue pair.
Shm(LocalShmChannel),
/// RDMA request/response queue pair.
Rdma(RdmaChannel),
}
impl std::fmt::Debug for ExpertClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExpertClient::Grpc(_) => write!(f, "ExpertClient::Grpc(..)"),
ExpertClient::Shm(_) => write!(f, "ExpertClient::Shm(..)"),
ExpertClient::Rdma(_) => write!(f, "ExpertClient::Rdma(..)"),
}
}
}
impl ExpertClient {
pub fn into_grpc_client(self) -> Option<OTelGrpcClientMiddleware> {
match self {
ExpertClient::Grpc(client) => Some(client),
ExpertClient::Shm(_) => None,
ExpertClient::Rdma(_) => None,
}
}
pub fn into_shm_channels(self) -> Option<LocalShmChannel> {
match self {
ExpertClient::Grpc(_) => None,
ExpertClient::Shm(channels) => Some(channels),
ExpertClient::Rdma(_) => None,
}
}
pub fn into_rdma_channels(self) -> Option<RdmaChannel> {
match self {
ExpertClient::Grpc(_) => None,
ExpertClient::Shm(_) => None,
ExpertClient::Rdma(channels) => Some(channels),
}
}
pub fn is_grpc(&self) -> bool {
matches!(self, ExpertClient::Grpc(_))
}
pub fn is_shm(&self) -> bool {
matches!(self, ExpertClient::Shm(_))
}
pub fn is_rdma(&self) -> bool {
matches!(self, ExpertClient::Rdma(_))
}
}
/// Resolves expert ids to clients for the workers that host them.
#[async_trait::async_trait]
pub trait ExpertRegistry {
/// Returns a client for one of the workers hosting `eid`.
async fn select(&mut self, eid: ExpertIdRef<'_>) -> EKResult<ExpertClient>;
/// Clears cached expert -> channel mappings (in `ExpertRegistryImpl`, per-host transports
/// are kept).
async fn reset(&mut self) -> EKResult<()>;
/// Forgets every channel that belongs to the worker host `host_id`.
async fn deregister(&mut self, host_id: &str);
}
/// A gRPC channel to a worker, tagged with the worker's hostname.
#[derive(Clone)]
struct GrpcChannelMeta {
/// Hostname of the worker node; matched against in `inner_deregister`.
host_id: String,
ch: Channel,
}
/// A shared-memory queue pair to a worker, tagged with the worker's hostname.
#[derive(Clone)]
struct ShmChannelMeta {
/// Hostname of the worker node; matched against in `inner_deregister`.
host_id: String,
ch: LocalShmChannel,
}
/// An RDMA queue pair to a worker, tagged with the worker's hostname.
#[derive(Clone)]
struct RdmaChannelMeta {
/// Hostname of the worker node; matched against in `inner_deregister`.
host_id: String,
ch: RdmaChannel,
}
/// One cached route to a worker for an expert, by transport.
#[derive(Clone)]
enum ChannelMeta {
Grpc(GrpcChannelMeta),
Shm(ShmChannelMeta),
Rdma(RdmaChannelMeta),
}
/// Per-host RDMA queue pair plus whether the TCP-based endpoint exchange has completed.
#[derive(Clone)]
struct RdmaNodeConnection {
req_queue: Arc<Mutex<RdmaQueue<ShmqWorkerReq>>>,
resp_queue: Arc<Mutex<RdmaQueue<ShmqWorkerResp>>>,
/// Set to `true` after `RdmaEndpointClient::connect_and_exchange` succeeds; cleared by
/// `reset_connection`.
connected: bool,
}
/// Default [`ExpertRegistry`]: lazily builds channels per expert from the state reader and
/// shares shm/RDMA transports between experts on the same host.
pub struct ExpertRegistryImpl {
/// Cached channels per expert; `select` picks one of them at random.
eid2channels: HashMap<ExpertId, Vec<ChannelMeta>>,
/// Shared-memory queue pairs keyed by worker hostname.
all_shm_channels: HashMap<String, LocalShmChannel>,
/// RDMA queue pairs keyed by worker hostname.
all_rdma_connections: HashMap<String, RdmaNodeConnection>,
/// Source of expert placement (`node_by_expert`).
reader: Box<dyn StateReader + Send + Sync>,
}
#[async_trait::async_trait]
impl ExpertRegistry for ExpertRegistryImpl {
async fn reset(&mut self) -> EKResult<()> {
self.inner_reset().await
}
async fn select(&mut self, eid: ExpertIdRef<'_>) -> EKResult<ExpertClient> {
let ch = self.inner_select(eid).await?;
match ch {
ChannelMeta::Grpc(meta) => {
let client = ServiceBuilder::new()
.layer_fn(OTelGrpcClientMiddleware::new)
.service(meta.ch.clone());
Ok(ExpertClient::Grpc(client))
}
ChannelMeta::Shm(meta) => Ok(ExpertClient::Shm(meta.ch.clone())),
ChannelMeta::Rdma(meta) => Ok(ExpertClient::Rdma(meta.ch.clone())),
}
}
async fn deregister(&mut self, host_id: &str) {
self.inner_deregister(host_id).await;
}
}
impl ExpertRegistryImpl {
async fn inner_reset(&mut self) -> EKResult<()> {
self.eid2channels.clear();
Ok(())
}
async fn inner_select(&mut self, eid: ExpertIdRef<'_>) -> EKResult<ChannelMeta> {
let channels = self.eid2channels.get(eid);
if let Some(channels) = channels {
if channels.is_empty() {
return self.create_then_select_channel(eid).await;
}
self.select_random(eid).await
} else {
self.create_then_select_channel(eid).await
}
}
async fn select_random(&mut self, eid: ExpertIdRef<'_>) -> EKResult<ChannelMeta> {
let channels = self.eid2channels.get(eid);
if let Some(channels) = channels {
if channels.is_empty() {
return self.create_then_select_channel(eid).await;
}
let idx = rand::random::<usize>() % channels.len();
Ok(channels[idx].clone())
} else {
self.create_then_select_channel(eid).await
}
}
async fn create_then_select_channel(&mut self, eid: ExpertIdRef<'_>) -> EKResult<ChannelMeta> {
let nodes = self.reader.node_by_expert(eid).await?;
for node in nodes {
let addr = node.config["addr"].as_str().unwrap().to_owned();
let channel = node.config["channel"].as_str().unwrap().to_owned();
match channel.as_str() {
"grpc" => {
let end = Channel::from_shared(addr)
.map_err(|e| EKError::InvalidInput(format!("invalid url for gRPC: {e}")))?;
let channel = end.connect().await?;
let meta = GrpcChannelMeta {
ch: channel,
host_id: node.hostname.clone(),
};
self.eid2channels
.entry(eid.to_owned())
.or_default()
.push(ChannelMeta::Grpc(meta));
}
"shm" => {
let shm_channel = self
.all_shm_channels
.entry(node.hostname.clone())
.or_insert_with(|| {
let req_queue = Arc::new(Mutex::new(ShmQueue::new(
&format!("ek-shmq-req-{}", node.hostname),
128,
)));
let resp_queue = Arc::new(Mutex::new(ShmQueue::new(
&format!("ek-shmq-resp-{}", node.hostname),
128,
)));
(req_queue, resp_queue)
});
let meta = ShmChannelMeta {
ch: shm_channel.clone(),
host_id: node.hostname.clone(),
};
self.eid2channels
.entry(eid.to_owned())
.or_default()
.push(ChannelMeta::Shm(meta));
}
"rdma" => {
self.setup_rdma_channel(&node, eid).await?;
}
_ => {
return Err(EKError::NotFound(format!(
"unknown channel type {channel} for expert {eid}"
)));
}
}
}
let res = self.eid2channels.get(eid).ok_or(EKError::NotFound(format!(
"no channel found for expert {eid}"
)))?;
if res.is_empty() {
return Err(EKError::NotFound(format!(
"no channel found for expert {eid}"
)));
}
let idx = rand::random::<usize>() % res.len();
Ok(res[idx].clone())
}
async fn setup_rdma_channel(
&mut self,
node: &crate::state::models::Node,
eid: &str,
) -> EKResult<()> {
log::debug!(
"registering RDMA channel for expert {eid} on node {}",
node.hostname
);
let worker_tcp_port = node
.config
.get("rdma_tcp_port")
.and_then(|v| v.as_u64())
.map(|p| p as u16);
if worker_tcp_port.is_none() {
log::warn!("Missing worker RDMA TCP port for node {}", node.hostname);
return Ok(());
}
let tcp_port = worker_tcp_port.unwrap();
let worker_addr = node.config["addr"].as_str().unwrap();
let worker_ip = self.extract_ip_from_addr(worker_addr)?;
let needs_new_connection = !self.all_rdma_connections.contains_key(&node.hostname);
if needs_new_connection {
let req_queue = RdmaQueue::<ShmqWorkerReq>::new(None, 256, true).map_err(|e| {
EKError::IoError(std::io::Error::other(format!(
"Failed to create controller RDMA request queue: {e}"
)))
})?;
let resp_queue = RdmaQueue::<ShmqWorkerResp>::new(None, 256, false).map_err(|e| {
EKError::IoError(std::io::Error::other(format!(
"Failed to create controller RDMA response queue: {e}"
)))
})?;
let req_queue_arc = Arc::new(Mutex::new(req_queue));
let resp_queue_arc = Arc::new(Mutex::new(resp_queue));
let connection = RdmaNodeConnection {
req_queue: req_queue_arc.clone(),
resp_queue: resp_queue_arc.clone(),
connected: false,
};
self.all_rdma_connections
.insert(node.hostname.clone(), connection);
self.connect_to_worker_via_tcp_with_retry(&node.hostname, &worker_ip, tcp_port)
.await?;
}
let connection = self.all_rdma_connections.get(&node.hostname).unwrap();
let rdma_channel = (connection.req_queue.clone(), connection.resp_queue.clone());
let meta = RdmaChannelMeta {
ch: rdma_channel,
host_id: node.hostname.clone(),
};
self.eid2channels
.entry(eid.to_owned())
.or_default()
.push(ChannelMeta::Rdma(meta));
Ok(())
}
async fn connect_to_worker_via_tcp(
&mut self,
hostname: &str,
worker_ip: &str,
tcp_port: u16,
) -> EKResult<()> {
if let Some(connection) = self.all_rdma_connections.get_mut(hostname) {
if connection.connected {
return Ok(());
}
log::info!(
"Establishing RDMA connection to worker {} (IP: {}) via TCP port {}",
hostname,
worker_ip,
tcp_port
);
let req_queue = connection.req_queue.clone();
let resp_queue = connection.resp_queue.clone();
RdmaEndpointClient::connect_and_exchange(worker_ip, tcp_port, req_queue, resp_queue)
.await
.map_err(|e| EKError::IoError(e))?;
connection.connected = true;
log::info!(
"🚀RDMA connection to worker {} established successfully",
hostname
);
}
Ok(())
}
async fn connect_to_worker_via_tcp_with_retry(
&mut self,
hostname: &str,
worker_ip: &str,
tcp_port: u16,
) -> EKResult<()> {
let mut last_error = None;
for attempt in 0..MAX_RETRIES {
match self
.connect_to_worker_via_tcp(hostname, worker_ip, tcp_port)
.await
{
Ok(()) => return Ok(()),
Err(e) => {
last_error = Some(e);
if attempt < MAX_RETRIES - 1 {
let delay_ms = BASE_DELAY_MS * 2_u64.pow(attempt);
log::warn!(
"RDMA connection attempt {} failed for worker {}, retrying in {}ms: {:?}",
attempt + 1,
hostname,
delay_ms,
last_error.as_ref().unwrap()
);
tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
}
}
}
}
self.all_rdma_connections.remove(hostname);
Err(last_error
.unwrap_or_else(|| EKError::IoError(std::io::Error::other("Unknown connection error"))))
}
fn extract_ip_from_addr(&self, addr: &str) -> EKResult<String> {
let url = Url::parse(addr)
.map_err(|e| EKError::InvalidInput(format!("Invalid URL format: {e}")))?;
let host = url
.host_str()
.ok_or_else(|| EKError::InvalidInput("No host found in URL".to_string()))?;
Ok(host.to_string())
}
pub async fn reset_connection(&mut self, hostname: &str) {
if let Some(connection) = self.all_rdma_connections.get_mut(hostname) {
connection.connected = false;
log::info!("Reset RDMA connection state for worker {}", hostname);
}
}
pub async fn inner_deregister(&mut self, host_id: &str) {
log::info!("deregister host_id {host_id}");
for (_, channels) in self.eid2channels.iter_mut() {
channels.retain(|meta| match meta {
ChannelMeta::Grpc(meta) => meta.host_id != host_id,
ChannelMeta::Shm(meta) => meta.host_id != host_id,
ChannelMeta::Rdma(meta) => meta.host_id != host_id,
});
}
self.all_shm_channels
.retain(|hostname, _| hostname != host_id);
self.reset_connection(host_id).await;
self.all_rdma_connections
.retain(|hostname, _| hostname != host_id);
log::info!("Deregistered worker: {}", host_id);
}
}
impl Default for ExpertRegistryImpl {
fn default() -> Self {
Self::new()
}
}
impl ExpertRegistryImpl {
pub fn new() -> Self {
Self {
eid2channels: HashMap::new(),
all_shm_channels: HashMap::new(),
all_rdma_connections: HashMap::new(),
reader: Box::new(StateReaderImpl::new()),
}
}
}
/// Upper bound (64 MiB) on a request's input tensor and a response's output tensor, in bytes.
const MAX_TENSOR_SIZE: usize = 64 * 1024 * 1024;
/// A request placed on a worker's request queue (shared memory or RDMA).
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct ShmqWorkerReq {
/// Id assigned by `new` from a process-wide counter starting at 1.
id: usize,
/// Expert id bytes, NUL-padded; `new` stores at most 63 bytes.
#[serde(with = "serde_arrays")]
expert_id: [u8; 64],
/// Raw input tensor bytes (not interpreted here); at most `MAX_TENSOR_SIZE` via `new`.
input_tensor: Vec<u8>,
}
impl ShmqWorkerReq {
pub fn new(expert_id: ExpertIdRef<'_>, input_tensor: &[u8]) -> Self {
static ID: AtomicUsize = AtomicUsize::new(1);
assert!(expert_id.len() < 64);
assert!(input_tensor.len() <= MAX_TENSOR_SIZE);
let mut expert_id_array = [0u8; 64];
let expert_id_bytes = expert_id.as_bytes();
let copy_len = std::cmp::min(expert_id_bytes.len(), 63);
expert_id_array[..copy_len].copy_from_slice(&expert_id_bytes[..copy_len]);
Self {
id: ID.fetch_add(1, Ordering::SeqCst),
expert_id: expert_id_array,
input_tensor: input_tensor.to_vec(),
}
}
pub fn id(&self) -> usize {
self.id
}
pub fn expert_id(&self) -> ExpertId {
let end = self.expert_id.iter().position(|&b| b == 0).unwrap_or(64);
String::from_utf8(self.expert_id[..end].to_vec()).unwrap()
}
pub fn input_tensor(&self) -> &[u8] {
&self.input_tensor
}
}
impl GeneralShmQueueBytes for ShmqWorkerReq {
    // Wire layout (integers little-endian and `usize`-wide):
    //   [id: usize][expert_id: 64 bytes, NUL-padded][tensor_len: usize][tensor: tensor_len bytes]
    const CAPACITY: usize =
        std::mem::size_of::<usize>() + 64 + std::mem::size_of::<usize>() + MAX_TENSOR_SIZE;

    /// Serialises `self` into the front of `slice` using the layout above.
    ///
    /// # Panics
    /// If `slice` is shorter than `self.len()`.
    fn write_to_slice(&self, slice: &mut [u8]) {
        // Header width must match `from_bytes`, `len` and `CAPACITY`; the previous hard-coded
        // 8 made `copy_from_slice` panic on targets where `usize` is 4 bytes.
        const WORD: usize = std::mem::size_of::<usize>();
        let tensor_len = self.input_tensor.len();
        let mut offset = 0;
        slice[offset..offset + WORD].copy_from_slice(&self.id.to_le_bytes());
        offset += WORD;
        slice[offset..offset + 64].copy_from_slice(&self.expert_id);
        offset += 64;
        slice[offset..offset + WORD].copy_from_slice(&tensor_len.to_le_bytes());
        offset += WORD;
        slice[offset..offset + tensor_len].copy_from_slice(&self.input_tensor);
    }

    /// Parses a request written by `write_to_slice`; bytes after the tensor are ignored.
    ///
    /// # Panics
    /// If `bytes` is shorter than the header plus the encoded tensor length.
    fn from_bytes(bytes: &[u8]) -> Self {
        const WORD: usize = std::mem::size_of::<usize>();
        let id = usize::from_le_bytes(bytes[..WORD].try_into().unwrap());
        let expert_id: [u8; 64] = bytes[WORD..WORD + 64].try_into().unwrap();
        let len_start = WORD + 64;
        let input_tensor_len =
            usize::from_le_bytes(bytes[len_start..len_start + WORD].try_into().unwrap());
        let data_start = len_start + WORD;
        // Two-step slicing avoids overflow in `data_start + len` for a corrupt length field.
        let input_tensor = bytes[data_start..][..input_tensor_len].to_vec();
        Self {
            id,
            expert_id,
            input_tensor,
        }
    }

    /// Number of bytes `write_to_slice` writes for this request.
    fn len(&self) -> usize {
        2 * std::mem::size_of::<usize>() + 64 + self.input_tensor.len()
    }
}
/// A response placed on a worker's response queue (shared memory or RDMA).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShmqWorkerResp {
/// Response id; presumably the id of the `ShmqWorkerReq` it answers — confirm with the worker.
id: usize,
/// Raw output tensor bytes (not interpreted here); at most `MAX_TENSOR_SIZE` via `new`.
output_tensor: Vec<u8>,
}
impl ShmqWorkerResp {
pub fn new(id: usize, output_tensor: Vec<u8>) -> Self {
assert!(output_tensor.len() <= MAX_TENSOR_SIZE);
Self { id, output_tensor }
}
pub fn id(&self) -> usize {
self.id
}
pub fn output_tensor(&self) -> &[u8] {
&self.output_tensor
}
}
impl GeneralShmQueueBytes for ShmqWorkerResp {
    // Wire layout (integers little-endian and `usize`-wide):
    //   [id: usize][tensor_len: usize][tensor: tensor_len bytes]
    const CAPACITY: usize =
        std::mem::size_of::<usize>() + std::mem::size_of::<usize>() + MAX_TENSOR_SIZE;

    /// Serialises `self` into the front of `slice` using the layout above.
    ///
    /// # Panics
    /// If `slice` is shorter than `self.len()`.
    fn write_to_slice(&self, slice: &mut [u8]) {
        // Header width must match `from_bytes`, `len` and `CAPACITY`; the previous hard-coded
        // 8 made `copy_from_slice` panic on targets where `usize` is 4 bytes.
        const WORD: usize = std::mem::size_of::<usize>();
        let tensor_len = self.output_tensor.len();
        let mut offset = 0;
        slice[offset..offset + WORD].copy_from_slice(&self.id.to_le_bytes());
        offset += WORD;
        slice[offset..offset + WORD].copy_from_slice(&tensor_len.to_le_bytes());
        offset += WORD;
        slice[offset..offset + tensor_len].copy_from_slice(&self.output_tensor);
    }

    /// Parses a response written by `write_to_slice`; bytes after the tensor are ignored.
    ///
    /// # Panics
    /// If `bytes` is shorter than the header plus the encoded tensor length.
    fn from_bytes(bytes: &[u8]) -> Self {
        const WORD: usize = std::mem::size_of::<usize>();
        let id = usize::from_le_bytes(bytes[..WORD].try_into().unwrap());
        let output_tensor_len = usize::from_le_bytes(bytes[WORD..2 * WORD].try_into().unwrap());
        // Two-step slicing avoids overflow in `start + len` for a corrupt length field.
        let output_tensor = bytes[2 * WORD..][..output_tensor_len].to_vec();
        Self { id, output_tensor }
    }

    /// Number of bytes `write_to_slice` writes for this response.
    fn len(&self) -> usize {
        2 * std::mem::size_of::<usize>() + self.output_tensor.len()
    }
}
/// Shared, async-locked handle to the process-wide expert registry.
pub type GlobalWorkerRegistry = Arc<Mutex<dyn ExpertRegistry + Send + Sync>>;

/// Returns a handle to the process-wide [`ExpertRegistryImpl`], creating it on first call.
pub fn get_registry() -> GlobalWorkerRegistry {
    static INSTANCE: OnceLock<Arc<Mutex<ExpertRegistryImpl>>> = OnceLock::new();
    let shared = INSTANCE.get_or_init(|| Arc::new(Mutex::new(ExpertRegistryImpl::new())));
    // Unsizing coercion from the concrete registry to the trait-object handle.
    let registry: GlobalWorkerRegistry = shared.clone();
    registry
}