use std::time::Instant;
use crate::{
metrics::{METRIC_WORKER_EXPERT_ACTIVATION, METRIC_WORKER_FORWARD},
proto::ek::worker::v1::{
ForwardReq, ForwardResp, computation_service_server::ComputationService,
},
};
use ek_base::utils::Defers;
use tonic::{Request, Response, Status};
use tracing::instrument;
use super::core::{EKInstanceGateSync, get_instance_gate_sync};
use tracing_opentelemetry::OpenTelemetrySpanExt;
#[derive(Debug)]
pub struct BasicExpertImpl {
gate_sync: &'static EKInstanceGateSync,
}
impl Default for BasicExpertImpl {
fn default() -> Self {
Self::new()
}
}
impl BasicExpertImpl {
pub fn new() -> Self {
Self {
gate_sync: get_instance_gate_sync(),
}
}
}
#[tonic::async_trait]
impl ComputationService for BasicExpertImpl {
#[instrument(skip(self, request))]
async fn forward(&self, request: Request<ForwardReq>) -> Result<Response<ForwardResp>, Status> {
let now = Instant::now();
let exp_id = request.get_ref().sequences[0].experts[0].clone();
tracing::debug!("[L1 {:?}] exp received!", &exp_id);
let res = self.inner_forward(request).await;
tracing::debug!("[L1 {:?}] completed in {:?}", &exp_id, now.elapsed());
res
}
}
impl BasicExpertImpl {
#[inline]
async fn inner_forward(
&self,
request: Request<ForwardReq>,
) -> Result<Response<ForwardResp>, Status> {
tracing::debug!(
"[L2 {:?}] rpc.forward() request: seq={}",
request.get_ref().sequences[0].experts[0],
request.get_ref().sequences.len(),
);
let exp_id = request.get_ref().sequences[0].experts[0].clone();
let start = Instant::now();
let start_cloned = start;
let settings = ek_base::config::get_ek_settings();
METRIC_WORKER_EXPERT_ACTIVATION
.with_label_values(&[
settings.worker.id.as_str(),
settings.inference.model_name.as_str(),
request.get_ref().sequences[0].experts[0].as_str(),
])
.inc_by(request.get_ref().sequences.len() as u64);
{
let worker_id = settings.worker.id.as_str();
let model = settings.inference.model_name.as_str();
let expert = request.get_ref().sequences[0].experts[0].as_str();
let count = request.get_ref().sequences.len();
log::info!(
worker_id:%,
model:%,
expert:%,
count:%
; "expert activation record",
);
}
Defers::defer(Box::new(move || {
let elapsed = start_cloned.elapsed();
METRIC_WORKER_FORWARD
.with_label_values(&[
settings.worker.id.as_str(),
settings.inference.model_name.as_str(),
])
.observe(elapsed.as_micros() as f64);
}));
tracing::debug!("[L2 {:?}] sync_gate.forward() start", &exp_id,);
let forward_now = Instant::now();
let req_inner = request.into_inner();
let gate_sync = self.gate_sync;
let cx = tracing::Span::current().context();
let cx_clone = cx.clone();
let res = tokio::task::spawn_blocking(move || {
let _guard = cx_clone.attach();
gate_sync.forward_sync(req_inner)
})
.await
.map_err(|e| {
log::error!("blocking task join error {e:?}");
Status::internal("blocking task error")
})?
.map_err(|e| {
log::error!("forward error {e:?}");
Status::internal("forward error")
})?;
tracing::debug!(
"[L2 {:?}] sync_gate.forward() end, elapsed {:?}",
&exp_id,
forward_now.elapsed(),
);
let res = Ok(Response::new(res));
tracing::debug!(
"[L2 {:?}] rpc.forward() end with {:?}",
&exp_id,
start.elapsed(),
);
res
}
}