use std::sync::OnceLock;
use crate::backend::{DType, Device, EkTensor, torch::TchTensor};
use super::{
ExpertWeight,
meta::{Expert, ExpertShape},
};
use ek_base::error::EKResult;
use tch::{
self,
nn::{self, Module},
};
/// Feed-forward expert evaluated through `tch`.
///
/// The forward module is not built at construction time; it is created from
/// `weight` the first time `load_module` runs and cached in `module`.
pub struct TorchFFN {
/// Hidden size; reported as `hidden` by `shape()` and used as the width of `rand_input`.
dim: usize,
/// Intermediate size; reported as `intermediate` by `shape()`.
intermediate_dim: usize,
/// Lazily initialised forward module (filled in by `load_module`).
module: OnceLock<nn::Sequential>,
/// Source weights (`up_w`, `down_w`, `gate_w`) the module is built from.
weight: ExpertWeight<TchTensor>,
/// Device the weights are moved to when the module is built.
device: Device,
}
/// Multiplies a 2-D tensor by per-tile scales and returns the result as BFloat16.
///
/// `x` has shape `[rows, cols]` and `s` has shape `[m, n]`, holding one scale
/// for every `block_size x block_size` tile of `x`. The scale grid must cover
/// `x` exactly, i.e. `m == ceil(rows / block_size)` and
/// `n == ceil(cols / block_size)`.
///
/// When `rows` or `cols` is not a multiple of `block_size`, `x` is zero-padded
/// up to `[m * block_size, n * block_size]` before tiling and the padding is
/// cropped away again afterwards, so the returned tensor always has the same
/// shape as `x`.
///
/// # Panics
///
/// Panics if either tensor is not 2-D, if `block_size` is not positive, or if
/// the scale grid does not match the tiled size of `x`.
pub fn w8a16_activate(x: &tch::Tensor, s: &tch::Tensor, block_size: i64) -> tch::Tensor {
    let scale_shape = s.size();
    let x_shape = x.size();
    assert!(scale_shape.len() == 2, "scale tensor must be 2-D");
    assert!(x_shape.len() == 2, "input tensor must be 2-D");
    assert!(block_size > 0, "block_size must be positive");

    let (rows, cols) = (x_shape[0], x_shape[1]);
    let (m, n) = (scale_shape[0], scale_shape[1]);

    // Size of the tensor once it is rounded up to whole tiles.
    let padded_rows = m * block_size;
    let padded_cols = n * block_size;
    let pad_rows = padded_rows - rows;
    let pad_cols = padded_cols - cols;
    assert!(
        (0..block_size).contains(&pad_rows) && (0..block_size).contains(&pad_cols),
        "scale grid [{m}, {n}] does not match input [{rows}, {cols}] with block size {block_size}"
    );

    // Work in f32 so the multiplication happens at full precision; padding is
    // appended after the cast so the filler never has to exist in the source kind.
    let mut padded = x.to_kind(tch::Kind::Float);
    if pad_rows > 0 {
        let filler = tch::Tensor::zeros([pad_rows, cols], (tch::Kind::Float, x.device()));
        padded = tch::Tensor::cat(&[&padded, &filler], 0);
    }
    if pad_cols > 0 {
        let filler = tch::Tensor::zeros([padded_rows, pad_cols], (tch::Kind::Float, x.device()));
        padded = tch::Tensor::cat(&[&padded, &filler], 1);
    }

    // [rows, cols] -> [m, n, block_size * block_size]: one flattened tile per scale.
    let tiles = padded
        .reshape([m, block_size, n, block_size])
        .permute([0, 2, 1, 3])
        .reshape([m, n, block_size * block_size]);
    let scales = s.reshape([m, n, 1]);

    // Undo the tiling to get back to a plain 2-D layout.
    let scaled = (tiles * scales)
        .to_kind(tch::Kind::BFloat16)
        .reshape([m, n, block_size, block_size])
        .permute([0, 2, 1, 3])
        .reshape([padded_rows, padded_cols]);

    if pad_rows > 0 || pad_cols > 0 {
        // Drop the zero padding so the output shape equals the input shape.
        scaled.narrow(0, 0, rows).narrow(1, 0, cols).contiguous()
    } else {
        scaled
    }
}
// SAFETY: NOTE(review) — this manually asserts that `&TorchFFN` may be shared
// across threads. Nothing in this file proves that; soundness presumably relies
// on the cached module and the weights only ever being read after
// initialisation. Confirm that the contained `tch` types tolerate concurrent
// shared access before relying on this impl.
unsafe impl Sync for TorchFFN {}
impl TorchFFN {
    /// Assembles a `TorchFFN` from its parts without building the module.
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        module: OnceLock<nn::Sequential>,
        weight: ExpertWeight<TchTensor>,
        device: Device,
    ) -> Self {
        Self {
            dim,
            intermediate_dim,
            module,
            weight,
            device,
        }
    }

    /// Returns the cached forward module, building it on the first call.
    ///
    /// Construction runs inside `tch::no_grad`.
    pub fn load_module(&self) -> &nn::Sequential {
        self.module
            .get_or_init(|| tch::no_grad(|| self.build_module()))
    }

    /// Builds the module: `down(up(x) * silu(gate(x)))`, where each projection
    /// is a matmul against the transposed weight.
    fn build_module(&self) -> nn::Sequential {
        let target: tch::Device = self.device.into();
        // Every weight gets the same treatment: BFloat16 first, then the target device.
        let stage = |raw: tch::Tensor| raw.to_kind(tch::Kind::BFloat16).to_device(target);

        let up_proj = stage(self.weight.up_w.inner().shallow_clone());
        let down_proj = stage(self.weight.down_w.inner().shallow_clone());
        let gate_proj = stage(self.weight.gate_w.inner().shallow_clone());

        nn::seq().add_fn(move |input| {
            let up = input.matmul(&up_proj.transpose(0, 1));
            let gate = input.matmul(&gate_proj.transpose(0, 1));
            let hidden = up * gate.silu();
            hidden.matmul(&down_proj.transpose(0, 1))
        })
    }

    /// Device the weights are placed on when the module is built.
    pub fn device(&self) -> Device {
        self.device
    }
}
impl Expert<TchTensor> for TorchFFN {
    /// Runs `x` through the (lazily built) module and wraps the output.
    fn forward(&self, x: &TchTensor) -> TchTensor {
        TchTensor(self.load_module().forward(&x.inner()))
    }

    /// Produces a random BFloat16 input of shape `[batch, dim]`.
    ///
    /// NOTE(review): the tensor is always created on `Device::CPU`, regardless
    /// of `self.device` — confirm this is intended for non-CPU experts.
    fn rand_input(&self, batch: usize) -> TchTensor {
        let dims = vec![batch, self.dim];
        TchTensor::rand(dims, DType::BFloat16, Device::CPU)
    }

    /// Reports the hidden and intermediate sizes of this expert.
    fn shape(&self) -> ExpertShape {
        let (hidden, intermediate) = (self.dim, self.intermediate_dim);
        ExpertShape {
            hidden,
            intermediate,
        }
    }

    /// Name of the backend implementing this expert.
    fn backend(&self) -> std::string::String {
        String::from("torch")
    }

    /// Creates an expert from an instance description and its weights; the
    /// module itself is left uninitialised until first use.
    fn construct(x: crate::x::EKInstance, weight: ExpertWeight<TchTensor>) -> EKResult<Self> {
        Ok(Self::new(
            x.hidden,
            x.intermediate,
            OnceLock::new(),
            weight,
            x.device,
        ))
    }
}
#[cfg(test)]
mod test {
use std::fs;
use ek_base::utils::workspace_root;
use safetensors::SafeTensors;
use tch::IndexOp;
use crate::{
backend::{Device, EkTensor},
ffn::{Expert, ExpertWeight, expert_torch::TorchFFN},
x::{self, test_root},
};
use super::{TchTensor, w8a16_activate};
/// Serialises a random tensor, reads it back, and compares the element sums.
#[test]
fn test_io() {
    let original = tch::Tensor::randn(vec![128, 128], (tch::Kind::Float, tch::Device::Cpu));

    // Encode a copy so `original` stays available for the comparison.
    let wrapped = TchTensor::from(original.copy());
    let encoded = wrapped.serialize();

    let archive = SafeTensors::deserialize(&encoded).unwrap();
    let view = archive.tensor("data").unwrap();
    let decoded = TchTensor::from_tensor_view(&view);

    let decoded_sum = decoded.inner().sum(tch::Kind::Float);
    let original_sum = original.sum(tch::Kind::Float);
    assert!(decoded_sum == original_sum);
}
/// Runs the expert on a recorded input and prints the summed difference
/// against the recorded output.
///
/// NOTE(review): the difference is only printed, not asserted — pick a
/// tolerance and assert on it if this is meant to guard against regressions.
#[test]
fn test_correctness() {
    // Weights for the expert under test.
    let st_fp = test_root()
        .join("resources")
        .join("qwen3-l0e1.weight.safetensors");
    let st_bytes = fs::read(st_fp).unwrap();
    let st = SafeTensors::deserialize(&st_bytes).unwrap();
    let weight = ExpertWeight::from_safetensor(&st, Device::CPU).unwrap();
    let inst = x::EKInstance {
        hidden: 2048,
        intermediate: 768,
        backend: x::ExpertBackendType::Torch,
        device: Device::CPU,
    };
    let ffn = TorchFFN::construct(inst, weight).unwrap();

    // Recorded input/output pair.
    let ground_truth_fp = test_root()
        .join("resources")
        .join("qwen3-l0e1.result.safetensors");
    let ground_truth_bytes = fs::read(ground_truth_fp).unwrap();
    let gt_st = SafeTensors::deserialize(&ground_truth_bytes).unwrap();
    let tv = gt_st.tensor("1-input").unwrap();
    let inp = TchTensor::from_tensor_view(&tv);
    let res = ffn.forward(&inp).inner();
    let truth = TchTensor::from_tensor_view(&gt_st.tensor("1-output").unwrap()).inner();

    // Kept for inspection in a debugger; also checks the slices convert to `f32`.
    let _vec1 = Vec::<f32>::try_from(res.i((0, 0..100))).unwrap();
    let _vec2 = Vec::<f32>::try_from(truth.i((0, 0..100))).unwrap();
    (res - truth).sum(tch::Kind::BFloat16).print();
}
/// Checks `w8a16_activate` against a recorded reference result.
#[test]
fn test_fp8_dequant() {
    let fixture = workspace_root()
        .join("ek-computation")
        .join("resources")
        .join("w8a16active-l0q_a_proj.safetensors");
    let raw = fs::read(fixture).unwrap();
    let archive = SafeTensors::deserialize(&raw).unwrap();

    // Look up all three entries first, then convert them to tensors.
    let src_view = archive.tensor("src").unwrap();
    let scale_view = archive.tensor("src_scale").unwrap();
    let expected_view = archive.tensor("triton_dequanted").unwrap();

    let src = TchTensor::from_tensor_view(&src_view).inner();
    let scale = TchTensor::from_tensor_view(&scale_view).inner();
    let expected = TchTensor::from_tensor_view(&expected_view).inner();

    let actual = w8a16_activate(&src, &scale, 128);

    // The difference is summed with its sign before `abs`, so errors of
    // opposite sign cancel each other out in this metric.
    let signed_total = (actual - expected).sum(tch::Kind::Double);
    let diff = signed_total.abs().double_value(&[]);
    assert!(diff < 0.2);
}
}
#[cfg(test)]
mod bench_ffn_concurrent {
use super::*;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task::JoinSet;
/// Builds a pool of random experts, warms them up, then times the same
/// workload serially and through spawned tasks and prints a comparison.
#[tokio::test]
async fn bench_torch_ffn_concurrent_vs_serial() {
    let ffn_count = 64;
    let batch_size = 1;
    let reqests_num = 512;
    let thread_num = 64;

    // Each expert gets its own random weights and an unbuilt module.
    let ffns: Vec<Arc<TorchFFN>> = (0..ffn_count)
        .map(|_| {
            Arc::new(TorchFFN {
                dim: 2048,
                intermediate_dim: 768,
                weight: ExpertWeight::from_rand_linear(
                    2048,
                    768,
                    crate::backend::DType::BFloat16,
                    crate::backend::Device::CPU,
                ),
                module: OnceLock::new(),
                device: Device::CPU,
            })
        })
        .collect();

    println!("🔥 FFN Concurrent vs Serial Performance Comparison Test");
    println!(" FFN count: {ffn_count}");
    println!(" Batch size: {batch_size}");
    println!(" Requests num: {reqests_num}");
    println!(" Thread num: {thread_num}");
    println!();

    warm_up_ffns(&ffns, batch_size).await;

    let serial_results = run_serial_benchmark(&ffns, batch_size, reqests_num).await;
    let concurrent_results =
        run_concurrent_benchmark(&ffns, batch_size, reqests_num, thread_num).await;

    compare_results(&serial_results, &concurrent_results, batch_size, thread_num);
}
/// Runs one forward pass per expert so every module is built before timing starts.
async fn warm_up_ffns(ffns: &[Arc<TorchFFN>], batch_size: usize) {
    println!("🚀 Warming up {} FFNs...", ffns.len());
    ffns.iter().for_each(|ffn| {
        let sample = ffn.rand_input(batch_size);
        // Only the side effect (module initialisation) matters here.
        let _ = ffn.forward(&sample);
    });
    println!("✅ Warmup completed");
    println!();
}
/// Times `requests_num` forward passes, one after another, each on a randomly
/// chosen expert, and returns the per-call latencies.
///
/// A single random input (created from the first expert) is reused for every call.
async fn run_serial_benchmark(
    ffns: &[Arc<TorchFFN>],
    batch_size: usize,
    requests_num: usize,
) -> Vec<Duration> {
    println!("📊 Starting serial test...");
    let inp = ffns[0].rand_input(batch_size);

    let results: Vec<Duration> = (0..requests_num)
        .map(|_| {
            let ffn = &ffns[rand::random_range(0..ffns.len())];
            let started = Instant::now();
            let output = ffn.forward(&inp);
            let elapsed = started.elapsed();
            // Keep the result observable so the call cannot be optimised away.
            let _ = std::hint::black_box(output);
            elapsed
        })
        .collect();

    println!("✅ Serial test completed");
    results
}
/// Times forward passes issued from `concurrent_tasks` parallel workers and
/// returns all per-call latencies.
///
/// `requests_num` is split evenly across the workers, with the remainder going
/// to the first workers; every worker performs at least one request. Each
/// worker creates one random input and reuses it for all of its calls.
///
/// The workers run via `spawn_blocking`: `forward` is synchronous and has no
/// await points, so tasks created with plain `spawn` would run back-to-back on
/// the current-thread runtime that `#[tokio::test]` uses by default, and
/// nothing would actually overlap.
///
/// Returns an empty vector when `concurrent_tasks` is zero.
async fn run_concurrent_benchmark(
    ffns: &[Arc<TorchFFN>],
    batch_size: usize,
    requests_num: usize,
    concurrent_tasks: usize,
) -> Vec<Duration> {
    println!("📊 Starting concurrent test...");
    if concurrent_tasks == 0 {
        // Nothing to run, and the split below would divide by zero.
        return Vec::new();
    }

    let base_reqs_num = requests_num / concurrent_tasks;
    let extra_reqs_num = requests_num % concurrent_tasks;
    println!(
        " Base requests per task: {base_reqs_num}, extra requests allocated to first {extra_reqs_num} tasks"
    );

    let mut join_set = JoinSet::new();
    for task_id in 0..concurrent_tasks {
        let ffns_clone = ffns.to_vec();
        // The first `extra_reqs_num` workers take one extra request; never schedule zero.
        let reqs_num = (base_reqs_num + usize::from(task_id < extra_reqs_num)).max(1);

        join_set.spawn_blocking(move || {
            let mut task_results = Vec::with_capacity(reqs_num);
            let inp = ffns_clone[0].rand_input(batch_size);
            for _ in 0..reqs_num {
                let ffn = &ffns_clone[rand::random_range(0..ffns_clone.len())];
                let start = Instant::now();
                let r = ffn.forward(&inp);
                task_results.push(start.elapsed());
                // Keep the result observable so the call cannot be optimised away.
                let _ = std::hint::black_box(r);
            }
            task_results
        });
    }

    let task_results = join_set.join_all().await;
    println!("✅ Concurrent test completed");
    task_results.into_iter().flatten().collect()
}
/// Prints average/min/max latency tables for both runs, the ratio and
/// difference of the averages, and the percentile breakdown of each run.
///
/// If either slice is empty, prints a message naming it and returns without
/// printing any statistics.
fn compare_results(
    serial_results: &[Duration],
    concurrent_results: &[Duration],
    batch_size: usize,
    concurrent_tasks: usize,
) {
    /// Prints the four value rows and the closing border shared by both tables.
    fn print_rows(avg: Duration, min: Duration, max: Duration, batch_size: usize) {
        println!("│ Avg latency: {:>8} μs │", avg.as_micros());
        println!("│ Min latency: {:>8} μs │", min.as_micros());
        println!("│ Max latency: {:>8} μs │", max.as_micros());
        println!(
            "│ Per sequence: {:>8.1} μs │",
            avg.as_micros() as f64 / batch_size as f64
        );
        println!("└─────────────────────────────────────────┘");
        println!();
    }

    match (serial_results.is_empty(), concurrent_results.is_empty()) {
        (true, _) => {
            println!("❌ Serial test results are empty!");
            return;
        }
        (false, true) => {
            println!("❌ Concurrent test results are empty!");
            return;
        }
        (false, false) => {}
    }

    // (average, minimum, maximum) of a non-empty sample set.
    let summarize = |samples: &[Duration]| {
        let total: Duration = samples.iter().sum();
        let avg = total / samples.len() as u32;
        let min = *samples.iter().min().unwrap();
        let max = *samples.iter().max().unwrap();
        (avg, min, max)
    };
    let (serial_avg, serial_min, serial_max) = summarize(serial_results);
    let (concurrent_avg, concurrent_min, concurrent_max) = summarize(concurrent_results);

    println!();
    println!("📈 Performance Comparison Results:");
    println!(
        " Actual tests: {} serial rounds, {} concurrent rounds",
        serial_results.len(),
        concurrent_results.len()
    );

    println!("┌─────────────────────────────────────────┐");
    println!("│ Serial Test Results │");
    println!("├─────────────────────────────────────────┤");
    print_rows(serial_avg, serial_min, serial_max, batch_size);

    println!("┌─────────────────────────────────────────┐");
    println!("│ Concurrent Test Results ({concurrent_tasks} tasks) │");
    println!("├─────────────────────────────────────────┤");
    print_rows(concurrent_avg, concurrent_min, concurrent_max, batch_size);

    let serial_avg_us = serial_avg.as_micros();
    let concurrent_avg_us = concurrent_avg.as_micros();
    let slowdown_ratio = concurrent_avg_us as f64 / serial_avg_us as f64;
    println!("🎯 Key Metrics:");
    println!(" • Performance degradation: {slowdown_ratio:.2}x");
    println!(
        " • Concurrent latency increase: {} μs",
        concurrent_avg_us as i64 - serial_avg_us as i64
    );
    println!();

    println!("📊 Latency Distribution Analysis:");
    print_latency_distribution("Serial", serial_results);
    print_latency_distribution("Concurrent", concurrent_results);
}
/// Prints the P50/P90/P95/P99 latencies of `results` in microseconds.
///
/// Percentile `p` is read at index `len * p / 100` of the sorted samples.
/// An empty `results` slice prints a "no samples" line instead of panicking.
fn print_latency_distribution(name: &str, results: &[Duration]) {
    println!(" {name} latency distribution:");
    if results.is_empty() {
        // Indexing below needs at least one sample.
        println!(" (no samples)");
        return;
    }

    // Saturate rather than wrap if a duration does not fit into 64 bits of microseconds.
    let mut latencies: Vec<u64> = results
        .iter()
        .map(|d| u64::try_from(d.as_micros()).unwrap_or(u64::MAX))
        .collect();
    latencies.sort_unstable();

    let percentile = |p: usize| latencies[latencies.len() * p / 100];
    let (p50, p90, p95, p99) = (
        percentile(50),
        percentile(90),
        percentile(95),
        percentile(99),
    );
    println!(" P50: {p50} μs, P90: {p90} μs, P95: {p95} μs, P99: {p99} μs");
}
}