use super::manager::{ExpertDB, ExpertDBSync, get_expert_db, get_expert_db_sync};
use crate::{controller::registry::ExpertIdRef, proto::ek};
use core::fmt;
use ek_base::error::EKResult;
use std::sync::{Arc, OnceLock};
use tokio;
use tracing::instrument;
/// Asynchronous gate in front of the expert database.
///
/// Holds a shared handle to an [`ExpertDB`] trait object guarded by a Tokio
/// read/write lock; callers must acquire the lock before touching the database.
pub struct EKInstanceGateAsync {
/// Shared, lock-protected expert database handle (populated from `get_expert_db()`
/// by the constructors in this file).
experts: Arc<tokio::sync::RwLock<dyn ExpertDB + Send + Sync>>,
}
/// Debug output is just the type name: the `experts` field is a locked trait
/// object and is deliberately left out of the rendering.
impl fmt::Debug for EKInstanceGateAsync {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct(..).finish()` emits exactly the name, so
        // writing the name directly produces the same output.
        f.write_str("EKInstanceGateAsync")
    }
}
impl Default for EKInstanceGateAsync {
fn default() -> Self {
let edb = get_expert_db();
EKInstanceGateAsync { experts: edb }
}
}
/// Synchronous (blocking) gate in front of the expert database.
///
/// Holds a shared handle to an [`ExpertDBSync`] trait object; unlike the async
/// gate there is no lock wrapper around the handle at this level.
pub struct EKInstanceGateSync {
/// Shared expert database handle (populated from `get_expert_db_sync()` by the
/// constructors in this file).
experts: Arc<dyn ExpertDBSync + Send + Sync>,
}
/// Debug output is just the type name: the `experts` field is a trait object
/// and is deliberately left out of the rendering.
impl fmt::Debug for EKInstanceGateSync {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct(..).finish()` emits exactly the name, so
        // writing the name directly produces the same output.
        f.write_str("EKInstanceGateSync")
    }
}
impl Default for EKInstanceGateSync {
fn default() -> Self {
let edb = get_expert_db_sync();
EKInstanceGateSync { experts: edb }
}
}
/// Returns the lazily-initialised, program-lifetime [`EKInstanceGateAsync`].
///
/// The gate is built with [`EKInstanceGateAsync::new`] on the first call; every
/// later call hands back a reference to that same value.
pub fn get_instance_gate() -> &'static EKInstanceGateAsync {
    static ASYNC_GATE: OnceLock<EKInstanceGateAsync> = OnceLock::new();
    ASYNC_GATE.get_or_init(EKInstanceGateAsync::new)
}
/// Returns the lazily-initialised, program-lifetime [`EKInstanceGateSync`].
///
/// The gate is built with [`EKInstanceGateSync::new`] on the first call; every
/// later call hands back a reference to that same value.
pub fn get_instance_gate_sync() -> &'static EKInstanceGateSync {
    static SYNC_GATE: OnceLock<EKInstanceGateSync> = OnceLock::new();
    SYNC_GATE.get_or_init(EKInstanceGateSync::new)
}
impl EKInstanceGateAsync {
    /// Builds a gate backed by the handle returned from `get_expert_db()`.
    pub fn new() -> Self {
        Self {
            experts: get_expert_db(),
        }
    }

    /// Lists the keys currently reported by the expert database.
    ///
    /// Takes the read lock on the database and holds it while the key listing
    /// is awaited.
    ///
    /// # Errors
    ///
    /// Returns whatever error the database's `keys()` call produces.
    pub async fn current_experts(&self) -> EKResult<Vec<String>> {
        let db = self.experts.read().await;
        db.keys().await
    }
}
impl EKInstanceGateSync {
    /// Builds a gate backed by the handle returned from `get_expert_db_sync()`.
    pub fn new() -> Self {
        EKInstanceGateSync {
            experts: get_expert_db_sync(),
        }
    }

    /// Runs one expert forward pass for `req` and wraps the result in a `ForwardResp`.
    ///
    /// The expert is selected by the single expert id of the *first* sequence in
    /// the request; `req.tensor` is parsed as a safetensors buffer and its tensor
    /// named `"data"` is passed to the expert's `forward`.
    ///
    /// # Errors
    ///
    /// Returns an error if the expert cannot be loaded, if `req.tensor` is not a
    /// valid safetensors buffer, if it has no `"data"` tensor, or if the expert's
    /// `forward` fails.
    ///
    /// # Panics
    ///
    /// Panics if `req.sequences` is empty or if the first sequence does not carry
    /// exactly one expert id.
    #[instrument(skip(self, req))]
    pub fn forward_sync(
        &self,
        req: ek::worker::v1::ForwardReq,
    ) -> EKResult<ek::worker::v1::ForwardResp> {
        // Check the request shape before anything indexes into it, so a bad
        // request fails with a descriptive message instead of a bare
        // out-of-bounds panic from the logging below.
        assert!(
            !req.sequences.is_empty(),
            "forward_sync: request contains no sequences"
        );
        assert_eq!(
            req.sequences[0].experts.len(),
            1,
            "forward_sync: first sequence must reference exactly one expert"
        );
        let exp_id = &req.sequences[0].experts[0];
        tracing::debug!(
            "[L3 {:?}] gate.forward_sync() start: seq={}",
            exp_id,
            req.sequences.len(),
        );
        let start = std::time::Instant::now();
        let exp = self.experts.load(exp_id)?;
        let now = std::time::Instant::now();
        tracing::debug!("[L3 {:?}] exp_backend.forward_sync() started", exp_id,);
        // The tensor bytes come from the request, so a malformed buffer is
        // reported as an error (as in `forward_sync_core`) rather than a panic.
        let st = safetensors::SafeTensors::deserialize(&req.tensor)?;
        let tv = st.tensor("data")?;
        let output_tensor = exp.forward(&tv)?;
        tracing::debug!(
            "[L3 {:?}] exp_backend.forward_sync() completed in {:?}",
            exp_id,
            now.elapsed()
        );
        tracing::debug!("output bytes_len={}", output_tensor.len());
        let resp = ek::worker::v1::ForwardResp { output_tensor };
        tracing::debug!(
            "[L3 {:?}] gate.forward_sync() end with {:?}",
            &exp_id,
            start.elapsed(),
        );
        Ok(resp)
    }

    /// Runs one expert forward pass on a raw safetensors buffer.
    ///
    /// Loads the expert identified by `expert_id`, parses `input_tensor` as a
    /// safetensors buffer, and passes its `"data"` tensor to the expert's
    /// `forward`, returning the output bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the expert cannot be loaded, if `input_tensor` is not
    /// a valid safetensors buffer, if it has no `"data"` tensor, or if the
    /// expert's `forward` fails.
    pub fn forward_sync_core(
        &self,
        expert_id: ExpertIdRef<'_>,
        input_tensor: &[u8],
    ) -> EKResult<Vec<u8>> {
        let exp = self.experts.load(expert_id)?;
        let st = safetensors::SafeTensors::deserialize(input_tensor)?;
        let tv = st.tensor("data")?;
        let res = exp.forward(&tv)?;
        Ok(res)
    }

    /// Lists the keys currently reported by the expert database.
    ///
    /// # Errors
    ///
    /// Returns whatever error the database's `keys()` call produces.
    // NOTE(review): no caller is visible in this file; the `expect(unused)`
    // attribute is kept as-is — confirm it is still needed.
    #[expect(unused)]
    pub fn current_experts_sync(&self) -> EKResult<Vec<String>> {
        self.experts.keys()
    }
}