use std::sync::Arc;
use std::time::{Duration, Instant};
use clap::Parser;
use log::{error, info};
use safetensors::tensor::{Dtype, TensorView};
use tokio::sync::Mutex;
use tokio::time::sleep;
use tonic::{Request, Response, Status, transport::Server};
use ek_computation::proto::ek::worker::v1::computation_service_server::{
ComputationService, ComputationServiceServer,
};
use ek_computation::proto::ek::worker::v1::{ForwardReq, ForwardResp};
/// Runtime settings for the service, filled in from the command line by `main`.
#[derive(Debug, Clone)]
struct Config {
    /// Size of the last dimension of every output tensor the service returns.
    expert_dim: usize,
    /// Artificial delay, in milliseconds, slept before each request is processed
    /// (no delay when zero).
    latency_ms: u64,
    /// Configured expert count; in this file it is only logged at startup.
    num_experts: usize,
}
/// Running counters accumulated over the lifetime of the server.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Default)]
struct ServerStats {
    /// Number of `forward` calls received.
    request_count: u64,
    /// Sum of the number of sequences carried by every received request.
    total_tokens_processed: u64,
    /// Sum of the input-tensor batch sizes of every request that passed shape validation.
    total_unique_tokens_processed: u64,
    /// Sum of the wall-clock time, in milliseconds, spent on successfully handled requests.
    total_processing_time_ms: u64,
}
/// gRPC service state: the immutable configuration plus the shared statistics.
struct ExpertKitService {
    /// Settings the service was started with.
    config: Config,
    /// Counters shared between request handlers; guarded by an async mutex.
    stats: Arc<Mutex<ServerStats>>,
}
impl ExpertKitService {
    /// Creates a service that uses `config` and starts with all statistics at zero.
    fn new(config: Config) -> Self {
        let stats = Arc::new(Mutex::new(ServerStats::default()));
        Self { config, stats }
    }
}
#[tonic::async_trait]
impl ComputationService for ExpertKitService {
    /// Handles one `Forward` RPC.
    ///
    /// Records the request in the shared statistics, optionally sleeps for the
    /// configured `latency_ms`, then delegates to `process_request` and logs
    /// timing information for successful requests.
    ///
    /// # Errors
    ///
    /// Returns whatever `Status` `process_request` produced; the error is logged
    /// before being handed back to the caller.
    async fn forward(&self, request: Request<ForwardReq>) -> Result<Response<ForwardResp>, Status> {
        let start_time = Instant::now();
        let request = request.into_inner();
        let num_sequences = request.sequences.len();

        // Keep the guard in its own scope so it is released before the simulated
        // latency below. Holding it across the sleep would make every concurrent
        // request wait for the previous one's delay to finish.
        {
            let mut stats = self.stats.lock().await;
            stats.request_count += 1;
            stats.total_tokens_processed += num_sequences as u64;
        }

        info!("Received request for instance {}", request.instance_id);
        info!("Processing {} sequences", num_sequences);

        if self.config.latency_ms > 0 {
            sleep(Duration::from_millis(self.config.latency_ms)).await;
        }

        // Re-acquire the lock only for the processing step and the stats update.
        let mut stats = self.stats.lock().await;
        match self.process_request(request, &mut stats).await {
            Ok(resp) => {
                let processing_time = start_time.elapsed();
                stats.total_processing_time_ms += processing_time.as_millis() as u64;
                info!("Request processed in {processing_time:.2?}");

                // Both ratios divide by the token total, so skip them while it is
                // still zero instead of logging NaN.
                if stats.total_tokens_processed > 0 {
                    let total_tokens = stats.total_tokens_processed as f64;
                    let avg_time_per_token = stats.total_processing_time_ms as f64 / total_tokens;
                    info!("Average processing time: {avg_time_per_token:.2}ms per token");

                    let unique_ratio = stats.total_unique_tokens_processed as f64 / total_tokens;
                    info!(
                        "Duplication ratio: {:.2}% unique tokens overall",
                        unique_ratio * 100.0
                    );
                }

                Ok(Response::new(resp))
            }
            Err(e) => {
                error!("Error processing request: {e}");
                Err(e)
            }
        }
    }
}
impl ExpertKitService {
    /// Validates the request's input tensor and builds an all-zero response tensor.
    ///
    /// The request payload is parsed as safetensors data and must contain a 2-D
    /// tensor named `"data"` whose first dimension equals the number of sequences
    /// in the request. The response holds a single F32 tensor named `"data"` of
    /// shape `[batch_size, max_experts, expert_dim]`, where `max_experts` is the
    /// largest expert count of any sequence (0 if there are no sequences) and
    /// `expert_dim` comes from the service configuration. Every element is zero.
    ///
    /// `stats.total_unique_tokens_processed` is increased by the batch size once
    /// the shape checks pass.
    ///
    /// # Errors
    ///
    /// * `internal` if the payload cannot be deserialized, has no `"data"` tensor,
    ///   or the response cannot be serialized.
    /// * `invalid_argument` if the tensor is not 2-D, its batch size differs from
    ///   the number of sequences, or the output size does not fit in `usize`.
    async fn process_request(
        &self,
        request: ForwardReq,
        stats: &mut ServerStats,
    ) -> Result<ForwardResp, Status> {
        let tensors = safetensors::SafeTensors::deserialize(&request.tensor)
            .map_err(|e| Status::internal(format!("Failed to deserialize tensor: {e}")))?;
        let input_tensor = tensors
            .tensor("data")
            .map_err(|e| Status::internal(format!("Failed to get 'data' tensor: {e}")))?;

        let shape = input_tensor.shape();
        let (batch_size, hidden_dim) = match shape {
            [batch, hidden] => (*batch, *hidden),
            _ => {
                return Err(Status::invalid_argument(format!(
                    "Expected 2D tensor, got {}D",
                    shape.len()
                )));
            }
        };

        if batch_size != request.sequences.len() {
            return Err(Status::invalid_argument(format!(
                "Batch size mismatch: tensor has {} but sequences has {}",
                batch_size,
                request.sequences.len()
            )));
        }
        info!("Input tensor shape: [{batch_size}, {hidden_dim}]");
        stats.total_unique_tokens_processed += batch_size as u64;

        let max_experts = request
            .sequences
            .iter()
            .map(|seq| seq.experts.len())
            .max()
            .unwrap_or(0);
        let expert_dim = self.config.expert_dim;

        // The sizes below come from the request and the CLI, so compute the byte
        // length with checked arithmetic rather than letting the product overflow.
        let byte_len = batch_size
            .checked_mul(max_experts)
            .and_then(|n| n.checked_mul(expert_dim))
            .and_then(|n| n.checked_mul(std::mem::size_of::<f32>()))
            .ok_or_else(|| {
                Status::invalid_argument(format!(
                    "Output tensor shape [{batch_size}, {max_experts}, {expert_dim}] is too large"
                ))
            })?;

        // An f32 zero is four zero bytes, so a zeroed byte buffer is the same
        // payload as a zeroed Vec<f32> and needs no unsafe reinterpretation.
        let output_bytes = vec![0u8; byte_len];
        info!(
            "Output tensor shape: [{}, {}, {}]",
            batch_size, max_experts, expert_dim
        );

        let output_shape = vec![batch_size, max_experts, expert_dim];
        let tensor_view = TensorView::new(Dtype::F32, output_shape, &output_bytes)
            .map_err(|e| Status::internal(format!("Failed to create tensor view: {e}")))?;
        let tensor_map = std::collections::HashMap::from([("data".to_string(), tensor_view)]);
        let buffer = safetensors::serialize(&tensor_map, &None)
            .map_err(|e| Status::internal(format!("Failed to serialize tensor: {e}")))?;

        Ok(ForwardResp {
            output_tensor: buffer,
        })
    }
}
// Command-line options for the server binary, parsed through clap's derive API.
// Plain `//` comments are used here on purpose: clap's derive turns `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct CliArgs {
// TCP port the server listens on; has a short and a long flag, default 50051.
#[arg(short, long, default_value_t = 50051)]
port: u16,
// Copied into `Config::expert_dim` by `main`; long flag only, default 7168.
#[arg(long, default_value_t = 7168)]
expert_dim: usize,
// Copied into `Config::latency_ms` by `main`; long flag only, default 0.
#[arg(long, default_value_t = 0)]
latency_ms: u64,
// Copied into `Config::num_experts` by `main`; has a short and a long flag, default 8.
#[arg(short, long, default_value_t = 8)]
num_experts: usize,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init_from_env(env_logger::Env::default().default_filter_or("info"));
let args = CliArgs::parse();
let config = Config {
expert_dim: args.expert_dim,
latency_ms: args.latency_ms,
num_experts: args.num_experts,
};
info!(
"Configuration: expert_dim={}, latency_ms={}, num_experts={}",
config.expert_dim, config.latency_ms, config.num_experts
);
let expert_service = ExpertKitService::new(config.clone());
let expert_service_server = ComputationServiceServer::new(expert_service);
let addr = format!("0.0.0.0:{}", args.port).parse()?;
info!("Starting server on {addr}");
Server::builder()
.add_service(expert_service_server)
.serve(addr)
.await?;
Ok(())
}