use std::{
collections::{BTreeMap, HashMap},
fmt,
sync::{Arc, OnceLock},
time,
};
use ek_base::{
config::get_ek_settings,
error::{EKError, EKResult},
utils::{Defers, PerfTimer},
};
use safetensors::SafeTensors;
use tch::{IndexOp, Tensor};
use tokio::sync::{Mutex, mpsc};
use tracing::{Instrument, instrument, span};
use crate::{
backend::{EkTensor, torch::TchTensor},
controller::registry::{ExpertClient, ExpertId, ExpertIdRef, ShmqWorkerReq, ShmqWorkerResp},
metrics::METRIC_CONTROLLER_INTRA_REQ,
proto::ek::worker::v1::{self},
};
use super::registry::{GlobalWorkerRegistry, get_registry};
/// Asynchronous executor interface: requests are handed over with `submit` and processed
/// by a separate call to `exec`.
#[async_trait::async_trait]
pub trait Executor {
/// Accepts a forward request and returns the receiving half of a channel on which
/// `ForwardResp` values for that request are delivered.
async fn submit(
&mut self,
req: &v1::ForwardReq,
) -> EKResult<mpsc::Receiver<Arc<v1::ForwardResp>>>;
/// Runs one execution pass. In the `NaiveExecutor` implementation in this file this
/// dispatches everything queued by earlier `submit` calls.
async fn exec(&mut self) -> EKResult<()>;
}
/// Identifier of a submitted request; allocated from `NaiveExecutor::req_id_cursor`.
type ReqId = u64;
/// Identifier of a sequence that is unique across requests; allocated from
/// `NaiveExecutor::seq_gid_cursor`.
type GlobalSeqId = u64;
/// Position of a sequence inside its own request's `sequences` list.
type LocalSeqIdx = usize;
/// Bookkeeping kept for one submitted request until all of its expert results are in.
struct IngressMeta {
/// Input tensor deserialized from the request; indexed along dimension 0 by local
/// sequence index when batches are assembled.
tensor: Tensor,
/// Channel on which the final `ForwardResp` is sent back to the submitter.
sender: mpsc::Sender<Arc<v1::ForwardResp>>,
/// `result[seq][expert]` is the output of that expert for that sequence; `None` until
/// the corresponding response has been processed.
result: Vec<Vec<Option<Tensor>>>,
}
// SAFETY: NOTE(review): this file gives no justification for why sharing `&IngressMeta`
// (and therefore the contained `Tensor`) between threads is sound — TODO confirm and
// record the invariant here.
unsafe impl Sync for IngressMeta {}
/// One (sequence, expert) work item queued for dispatch to an expert.
#[derive(Clone)]
struct EgressMeta {
/// Request the sequence belongs to.
req_id: ReqId,
/// Global id of the sequence.
seq_gid: GlobalSeqId,
/// Position of the expert in the sequence's `experts` list; used as the second index
/// into `IngressMeta::result`.
expert_idx: usize,
}
/// Reply to a forwarded expert request, tagged with the kind of client it came back on.
enum ForwardResponse {
    /// Reply obtained from the gRPC computation service client.
    Grpc(v1::ForwardResp),
    /// Reply received on an `ExpertClient::Shm` channel pair.
    Shm(ShmqWorkerResp),
    /// Reply received on an `ExpertClient::Rdma` channel pair.
    Rdma(ShmqWorkerResp),
}

impl ForwardResponse {
    /// Borrows the serialized output tensor bytes regardless of the variant.
    fn output_tensor(&self) -> &[u8] {
        match self {
            Self::Grpc(grpc) => &grpc.output_tensor,
            // Both queue-based variants wrap the same response type.
            Self::Shm(queued) | Self::Rdma(queued) => queued.output_tensor(),
        }
    }
}
/// A response that was received by a task waiting for a different request id and is
/// parked until the task that owns it picks it up.
#[derive(Debug, Clone)]
enum PendingResponse {
    /// Parked by a task polling an `ExpertClient::Shm` channel.
    Shm(ShmqWorkerResp),
    /// Parked by a task polling an `ExpertClient::Rdma` channel.
    Rdma(ShmqWorkerResp),
}

impl PendingResponse {
    /// Returns the wrapped response if this is the `Shm` variant, otherwise `None`.
    fn into_shm(self) -> Option<ShmqWorkerResp> {
        if let Self::Shm(resp) = self {
            Some(resp)
        } else {
            None
        }
    }

    /// Returns the wrapped response if this is the `Rdma` variant, otherwise `None`.
    fn into_rdma(self) -> Option<ShmqWorkerResp> {
        if let Self::Rdma(resp) = self {
            Some(resp)
        } else {
            None
        }
    }
}
/// Executor that groups queued sequences by expert and, on each execution pass, sends one
/// batched request per expert.
pub struct NaiveExecutor {
/// Queued work items grouped by target expert; drained by `inner_execute`.
pending_egress: BTreeMap<ExpertId, Vec<EgressMeta>>,
/// Requests whose results are still being collected, keyed by request id.
pending_ingress: BTreeMap<ReqId, IngressMeta>,
/// Responses received by a task that was waiting for a different request id, keyed by
/// the response's id, shared with the spawned dispatch tasks.
pending_resp: Arc<Mutex<HashMap<usize, PendingResponse>>>,
/// Maps a global sequence id to its owning request and its index inside that request.
seq_mapping: BTreeMap<GlobalSeqId, (ReqId, LocalSeqIdx)>,
/// Most recently issued global sequence id (incremented before each use).
seq_gid_cursor: u64,
/// Most recently issued request id (incremented before each use).
req_id_cursor: u64,
/// Registry from which a client is selected for each expert.
registry: GlobalWorkerRegistry,
}
/// Opaque `Debug` output: only the type name is written, none of the fields.
impl fmt::Debug for NaiveExecutor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NaiveExecutor")
    }
}
/// `Executor` implementation that forwards to the inherent `inner_*` methods.
#[async_trait::async_trait]
impl Executor for NaiveExecutor {
/// Delegates to `inner_submit`.
async fn submit(
&mut self,
req: &v1::ForwardReq,
) -> EKResult<mpsc::Receiver<Arc<v1::ForwardResp>>> {
self.inner_submit(req).await
}
/// Delegates to `inner_execute`.
async fn exec(&mut self) -> EKResult<()> {
self.inner_execute().await
}
}
impl NaiveExecutor {
async fn inner_submit(
&mut self,
req: &v1::ForwardReq,
) -> EKResult<mpsc::Receiver<Arc<v1::ForwardResp>>> {
let (sender, receiver) = mpsc::channel(1);
log::debug!("submit request, seq_len {:?}", req.sequences.len());
let span = span!(
tracing::Level::INFO,
"naive_executor_submit",
seq_len = req.sequences.len(),
instance_id = req.instance_id.as_str()
);
let _enter = span.enter();
let inp_safetensor = SafeTensors::deserialize(&req.tensor)?;
let inp_view = inp_safetensor.tensor("data")?;
let inp_tensor = TchTensor::from(&inp_view);
let mut result = vec![];
for i in &req.sequences {
let mut experts = Vec::new();
for _ in &i.experts {
experts.push(None);
}
result.push(experts);
}
let meta = IngressMeta {
tensor: inp_tensor.inner(),
sender,
result,
};
self.req_id_cursor += 1;
self.pending_ingress.insert(self.req_id_cursor, meta);
self.break_down_to_egress(req, self.req_id_cursor);
Ok(receiver)
}
/// Gathers the input rows of the given sequences into one stacked tensor.
///
/// Every global sequence id is resolved through `seq_mapping` to its owning request and
/// local index; the request's input tensor is indexed along dimension 0 with that local
/// index and the selected rows are stacked, in `gids` order, along a new dimension 0.
///
/// The only caller in this file passes ids taken from a `pending_egress` entry, which is
/// created by pushing an item and is therefore never empty.
/// NOTE(review): an empty `gids` would reach `Tensor::stack` with an empty list — confirm
/// whether that needs handling if new callers appear.
///
/// # Errors
/// Returns `EKError::NotFound` when a sequence id has no mapping or its request is no
/// longer in `pending_ingress`.
fn assemble_seq_tensors(&self, gids: Vec<GlobalSeqId>) -> EKResult<Tensor> {
    let mut rows = Vec::with_capacity(gids.len());
    for gid in gids {
        // `ok_or_else` keeps the error (and its allocation) off the success path.
        let (rid, lid) = self
            .seq_mapping
            .get(&gid)
            .ok_or_else(|| EKError::NotFound("seq not found".into()))?;
        let ingress_meta = self
            .pending_ingress
            .get(rid)
            .ok_or_else(|| EKError::NotFound("req tensor not found".into()))?;
        rows.push(ingress_meta.tensor.i(*lid as i64));
    }
    let out = Tensor::stack(&rows, 0);
    log::debug!(
        "assemble seq tensor, vec_len={} shape={:?}",
        rows.len(),
        out.size()
    );
    Ok(out)
}
/// Runs one execution pass over everything currently queued in `pending_egress`.
///
/// For each expert, in key order, a client is selected from the registry, the queued
/// sequences are assembled into one tensor, and a task is spawned that sends the request
/// through the selected client (`Grpc`, `Shm` or `Rdma`) and waits for the reply. After all
/// tasks are spawned they are awaited in spawn order; row `i` of each reply is stored in the
/// result table slot of the `i`-th work item sent to that expert. Finally `output` emits
/// every request whose table is complete.
///
/// # Errors
/// Returns early if batch assembly fails, a spawned task cannot be joined, a reply cannot be
/// deserialized or lacks a `"data"` entry, or a work item's sequence/request is no longer
/// registered.
///
/// # Panics
/// Panics if the request id recorded in `seq_mapping` differs from the work item's.
#[instrument]
pub async fn inner_execute(&mut self) -> EKResult<()> {
let mut tit = PerfTimer::new("inner_execute");
// `handles[i]` is the task spawned for `chips[i]`; both are pushed exactly once per
// dispatched expert so the indices stay aligned.
let mut handles = vec![];
let mut chips: Vec<(ExpertId, Vec<EgressMeta>)> = vec![];
let settings = get_ek_settings();
// `pop_first` drains the queue in ascending expert-id order.
while let Some((expert_id, egress_meta)) = self.pending_egress.pop_first() {
let expert_id: ExpertIdRef = expert_id.as_ref();
// NOTE(review): when no client can be selected the work items are discarded; nothing
// here marks the owning requests as failed, so their result slots stay `None` and
// `output` will never emit them — confirm whether that is intended.
let Ok(client) = self.registry.lock().await.select(expert_id).await else {
log::warn!("failed to select client for expert {expert_id}");
continue;
};
chips.push((expert_id.to_owned(), egress_meta.to_owned()));
// Build the batch: one row per queued work item, in queue order.
let seq_gids = egress_meta
.iter()
.map(|e| e.seq_gid)
.collect::<Vec<GlobalSeqId>>();
let egress_tensor = self.assemble_seq_tensors(seq_gids)?;
log::debug!("egress tensor shape={:?}", egress_tensor.size());
let serialized_tensor = TchTensor::from(egress_tensor).serialize();
// One `SequenceInfo` per row, each naming only this expert.
let seqs = egress_meta
.iter()
.map(|_e| v1::forward_req::SequenceInfo {
experts: vec![expert_id.to_owned()],
})
.collect::<Vec<_>>();
let pending_resp = self.pending_resp.clone();
let expert_id = expert_id.to_owned();
match client {
// gRPC: a single `forward` call; message size limits are raised to 1 GiB.
ExpertClient::Grpc(grpc_channel) => {
let mut cli =
v1::computation_service_client::ComputationServiceClient::new(grpc_channel)
.max_decoding_message_size(1024 * 1024 * 1024)
.max_encoding_message_size(1024 * 1024 * 1024);
let f = tokio::spawn(
async move {
let req = v1::ForwardReq {
instance_id: "0".into(),
tensor: serialized_tensor,
sequences: seqs,
};
let start = time::Instant::now();
// Records the elapsed microseconds in the metric, labelled with the model name.
// NOTE(review): presumably `Defers` runs the closure when `_d` is dropped — confirm.
let _d = Defers::defer(Box::new(move || {
let elapsed = start.elapsed();
METRIC_CONTROLLER_INTRA_REQ
.with_label_values(&[settings.inference.model_name.as_str()])
.observe(elapsed.as_micros() as f64);
}));
// The error is logged and replaced by `()`; the join loop below only sees `Err(())`.
cli.forward(req)
.await
.map(|resp| ForwardResponse::Grpc(resp.into_inner()))
.map_err(|e| log::error!("forward error: {e}"))
}
.in_current_span(),
);
handles.push(f);
}
// Shm: send on one channel, then poll the other for the reply with the matching id.
ExpertClient::Shm((send_channel, recv_channel)) => {
let fu = async move {
let req = ShmqWorkerReq::new(expert_id.as_ref(), &serialized_tensor);
let start = time::Instant::now();
let _d = Defers::defer(Box::new(move || {
let elapsed = start.elapsed();
METRIC_CONTROLLER_INTRA_REQ
.with_label_values(&[settings.inference.model_name.as_str()])
.observe(elapsed.as_micros() as f64);
}));
// Retry the send every 100 microseconds until it succeeds.
// NOTE(review): there is no retry limit or timeout.
while send_channel.lock().await.send(&req).is_err() {
log::warn!("failed to send request to expert {expert_id}");
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
log::debug!(
"request sent for expert {}, waiting for response",
expert_id
);
// Wait for our reply: first check whether another task already parked it in
// `pending_resp`, otherwise poll the receive channel and park any reply that
// carries a different id.
// NOTE(review): a parked entry with our id but the `Rdma` variant is removed and
// dropped by `into_shm`; the loop itself has no timeout.
let resp = loop {
if let Some(pending_resp) = pending_resp.lock().await.remove(&req.id())
&& let Some(resp) = pending_resp.into_shm()
{
break resp;
}
match recv_channel.lock().await.recv() {
Ok(resp) => {
if resp.id() != req.id() {
log::debug!(
"received response for expert {} but id not correct",
expert_id
);
pending_resp
.lock()
.await
.insert(resp.id(), PendingResponse::Shm(resp));
tokio::task::yield_now().await;
continue;
}
break resp;
}
Err(_) => {
tokio::task::yield_now().await;
}
}
};
Ok(ForwardResponse::Shm(resp))
}
.in_current_span();
handles.push(tokio::spawn(fu));
}
// Rdma: same send/poll protocol as the `Shm` branch, with `Rdma`-tagged parked replies.
ExpertClient::Rdma((send_channel, recv_channel)) => {
let fu = async move {
let req = ShmqWorkerReq::new(expert_id.as_ref(), &serialized_tensor);
let start = time::Instant::now();
let _d = Defers::defer(Box::new(move || {
let elapsed = start.elapsed();
METRIC_CONTROLLER_INTRA_REQ
.with_label_values(&[settings.inference.model_name.as_str()])
.observe(elapsed.as_micros() as f64);
}));
// Retry the send every 100 microseconds until it succeeds (no retry limit).
loop {
match send_channel.lock().await.send(&req) {
Ok(_) => break,
Err(e) => {
log::warn!("failed to send RDMA request to expert {expert_id}: {e}");
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
}
}
log::debug!(
"RDMA request sent for expert {}, waiting for response",
expert_id
);
// Parked-reply check followed by a channel poll, as in the `Shm` branch.
let resp = loop {
if let Some(pending_resp) = pending_resp.lock().await.remove(&req.id())
&& let Some(resp) = pending_resp.into_rdma() {
break resp;
}
match recv_channel.lock().await.recv() {
Ok(resp) => {
if resp.id() != req.id() {
log::debug!(
"received RDMA response for expert {} but id not correct",
expert_id
);
pending_resp.lock().await.insert(resp.id(), PendingResponse::Rdma(resp));
tokio::task::yield_now().await;
continue;
}
break resp;
}
Err(_) => {
tokio::task::yield_now().await;
}
}
};
Ok(ForwardResponse::Rdma(resp))
}
.in_current_span();
handles.push(tokio::spawn(fu));
}
}
}
tit.stop("egress_req_sent");
// Join the tasks in spawn order and scatter each reply into the result tables.
// NOTE(review): every `?` below returns immediately, leaving the remaining handles
// un-awaited and `output` uncalled for this pass.
for (egress_idx, handle) in handles.into_iter().enumerate() {
let egress = &chips[egress_idx];
// Outer `?` is the join error; the inner `Err(())` is a failed forward call.
// NOTE(review): on a failed call the affected result slots stay `None`.
let Ok(res) = handle.await? else {
log::error!("failed to receive response for expert {}", egress.0);
continue;
};
let res_safetensor = SafeTensors::deserialize(res.output_tensor())?;
let view = res_safetensor.tensor("data")?;
let res_tensor = TchTensor::from(&view).inner();
log::debug!("received tensor shape={:?}", res_tensor.size());
// Row `seq_idx` of the reply belongs to the `seq_idx`-th work item sent to this expert.
for (seq_idx, egress_meta) in egress.1.iter().enumerate() {
let id_mapping = self
.seq_mapping
.get(&egress_meta.seq_gid)
.ok_or(EKError::NotFound("no seq mapping".into()))?;
// Invariant: the mapping and the work item must name the same request.
assert!(id_mapping.0 == egress_meta.req_id);
let lid = id_mapping.1;
let meta = self
.pending_ingress
.get_mut(&egress_meta.req_id)
.ok_or(EKError::NotFound("no ingress req found".into()))?;
let seq_completion = &mut meta.result[lid];
seq_completion[egress_meta.expert_idx] = Some(res_tensor.i(seq_idx as i64));
}
}
tit.stop("remote resp joined");
// Emit every request whose result table is now complete.
self.output().await;
tit.stop("output generated");
Ok(())
}
async fn output(&mut self) {
let mut removed = vec![];
for (req_id, meta) in self.pending_ingress.iter() {
let completed = meta.result.iter().all(|x| x.iter().all(|v| v.is_some()));
if !completed {
continue;
}
let res_tensors = meta
.result
.iter()
.map(|x| {
let must_tensor = x.iter().map(|x| x.as_ref().unwrap()).collect::<Vec<_>>();
Tensor::stack(&must_tensor, 0)
})
.collect::<Vec<_>>();
let output_tensor = Tensor::stack(&res_tensors, 0);
log::debug!("output tensor shape: {:?}", output_tensor.size());
let serialized_tensor = TchTensor::from(output_tensor).serialize();
let resp = v1::ForwardResp {
output_tensor: serialized_tensor,
};
let send_res = meta
.sender
.send_timeout(Arc::new(resp), std::time::Duration::from_secs(5))
.await;
if let Err(e) = send_res {
log::error!("send forward response error: {e}");
}
removed.push(*req_id);
}
for rid in removed {
self.pending_ingress.remove(&rid);
let gids_to_remove = self
.seq_mapping
.iter()
.filter(|x| x.1.0 == rid)
.map(|x| *x.0)
.collect::<Vec<_>>();
for key in gids_to_remove {
self.seq_mapping.remove(&key);
}
}
}
/// Splits a request into per-expert work items.
///
/// Every sequence of `req` is given a global sequence id (even if it lists no experts), and
/// for each expert the sequence names, an `EgressMeta` is appended to that expert's queue in
/// `pending_egress`. The expert's position in the sequence's list is recorded so that the
/// reply can be written to the matching result slot.
#[instrument(skip(self, req))]
fn break_down_to_egress(&mut self, req: &v1::ForwardReq, req_id: ReqId) {
    for (seq_lid, seq) in req.sequences.iter().enumerate() {
        let seq_gid = self.add_seq(req_id, seq_lid);
        for (expert_idx, expert) in seq.experts.iter().enumerate() {
            let work_item = EgressMeta {
                req_id,
                seq_gid,
                expert_idx,
            };
            let queue = self.pending_egress.entry(expert.clone()).or_default();
            queue.push(work_item);
        }
    }
}
/// Allocates the next global sequence id, records that it belongs to local sequence
/// `seq_lid` of request `rid`, and returns the new id.
fn add_seq(&mut self, rid: ReqId, seq_lid: LocalSeqIdx) -> GlobalSeqId {
    let gid = self.seq_gid_cursor + 1;
    self.seq_gid_cursor = gid;
    self.seq_mapping.insert(gid, (rid, seq_lid));
    gid
}
}
impl Default for NaiveExecutor {
fn default() -> Self {
Self::new()
}
}
impl NaiveExecutor {
    /// Creates an executor with empty queues and tables, both id cursors at zero, and the
    /// registry returned by `get_registry()`.
    pub fn new() -> Self {
        let registry = get_registry();
        let pending_resp = Arc::new(Mutex::new(HashMap::new()));
        Self {
            registry,
            pending_resp,
            req_id_cursor: 0,
            seq_gid_cursor: 0,
            seq_mapping: BTreeMap::new(),
            pending_ingress: BTreeMap::new(),
            pending_egress: BTreeMap::new(),
        }
    }
}
/// Returns a handle to the process-wide executor.
///
/// The executor is a `NaiveExecutor` created on the first call and stored in a `OnceLock`;
/// every call returns a clone of the same `Arc`.
pub fn get_executor() -> Arc<Mutex<dyn Executor + Send>> {
    static INSTANCE: OnceLock<Arc<Mutex<dyn Executor + Send>>> = OnceLock::new();
    let shared = INSTANCE.get_or_init(|| Arc::new(Mutex::new(NaiveExecutor::new())));
    Arc::clone(shared)
}