use std::sync::OnceLock;

use crate::backend::{DType, Device, EkTensor, torch::TchTensor};

use super::{
    ExpertWeight,
    meta::{Expert, ExpertShape},
};
use ek_base::error::EKResult;
use tch::{
    self,
    nn::{self, Module},
};

/// Feed-forward expert evaluated through `tch`.
///
/// The forward module is not built at construction time: `load_module` creates it on first use
/// from `weight` and caches it in `module`.
pub struct TorchFFN {
    /// Reported as `hidden` by `shape()` and used as the trailing dimension of `rand_input`.
    dim: usize,
    /// Reported as `intermediate` by `shape()`.
    intermediate_dim: usize,
    /// Lazily initialised forward module; filled by the first call to `load_module`.
    module: OnceLock<nn::Sequential>,
    /// Source weights; `load_module` reads `up_w`, `gate_w` and `down_w` from here.
    weight: ExpertWeight<TchTensor>,
    /// Device the weights are moved to when the module is built.
    device: Device,
}

/// Applies one scale per `block_size x block_size` tile of a 2-D tensor.
///
/// `x` is cast to `f32` and cut into square tiles; the tile at grid position `(i, j)` is
/// multiplied by `s[i, j]`, and the result is returned as `BFloat16` with the same shape as `x`.
/// When a dimension of `x` is not a multiple of `block_size`, `x` is zero-padded up to the grid
/// described by `s` for the computation and the padding is cropped off again before returning.
/// (The name suggests 8-bit weights with 16-bit activations; the input kind is not enforced here.)
///
/// # Arguments
/// * `x` - 2-D tensor of shape `[rows, cols]`.
/// * `s` - 2-D scale tensor of shape `[ceil(rows / block_size), ceil(cols / block_size)]`.
/// * `block_size` - Edge length of a tile; must be positive.
///
/// # Panics
/// Panics if either tensor is not 2-D, if `block_size` is not positive, or if the scale grid
/// does not tile `x` (i.e. it is too small, or more than one partial tile larger than needed).
pub fn w8a16_activate(x: &tch::Tensor, s: &tch::Tensor, block_size: i64) -> tch::Tensor {
    let scale_shape = s.size();
    let x_shape = x.size();
    assert!(scale_shape.len() == 2, "scale tensor must be 2-D");
    assert!(x_shape.len() == 2, "input tensor must be 2-D");
    assert!(block_size > 0, "block_size must be positive");

    let (rows, cols) = (x_shape[0], x_shape[1]);
    let (m, n) = (scale_shape[0], scale_shape[1]);

    // The scale grid dictates the padded extent: every tile needs exactly one scale.
    let padded_rows = m * block_size;
    let padded_cols = n * block_size;
    let pad_rows = padded_rows - rows;
    let pad_cols = padded_cols - cols;
    assert!(
        (0..block_size).contains(&pad_rows) && (0..block_size).contains(&pad_cols),
        "scale grid {m}x{n} does not tile a {rows}x{cols} tensor with block size {block_size}"
    );

    // Cast first so that padding and concatenation happen in a kind every backend supports.
    let mut padded = x.to_kind(tch::Kind::Float);
    if pad_rows > 0 {
        let filler = tch::Tensor::zeros([pad_rows, cols], (tch::Kind::Float, x.device()));
        padded = tch::Tensor::cat(&[&padded, &filler], 0);
    }
    if pad_cols > 0 {
        let filler = tch::Tensor::zeros([padded_rows, pad_cols], (tch::Kind::Float, x.device()));
        padded = tch::Tensor::cat(&[&padded, &filler], 1);
    }

    // [m*b, n*b] -> [m, b, n, b] -> [m, n, b, b] -> [m, n, b*b]: one row per tile.
    let tiles = padded
        .reshape([m, block_size, n, block_size])
        .permute([0, 2, 1, 3])
        .reshape([m, n, block_size * block_size]);
    // Trailing unit axis lets each scale broadcast over its whole tile.
    let scales = s.reshape([m, n, 1]);

    // Undo the tiling to get back to the (padded) row-major layout.
    let scaled = (tiles * scales)
        .to_kind(tch::Kind::BFloat16)
        .reshape([m, n, block_size, block_size])
        .permute([0, 2, 1, 3])
        .reshape([padded_rows, padded_cols]);

    if pad_rows > 0 || pad_cols > 0 {
        // Crop the zero padding so the result matches the shape of `x`.
        scaled.narrow(0, 0, rows).narrow(1, 0, cols).contiguous()
    } else {
        scaled
    }
}

// SAFETY: NOTE(review) — no justification is recorded in this file. This impl promises that
// `&TorchFFN` can be used from several threads at once, which covers the lazily initialised
// `module` and the tensors held in `weight`. Presumably it exists so that `Arc<TorchFFN>` can be
// moved into spawned tasks; confirm that concurrent `forward` calls on the underlying torch
// tensors are actually sound before relying on it.
unsafe impl Sync for TorchFFN {}

impl TorchFFN {
    /// Assembles an expert from its parts without building the forward module.
    ///
    /// `module` may be passed empty; it is populated on the first call to [`Self::load_module`].
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        module: OnceLock<nn::Sequential>,
        weight: ExpertWeight<TchTensor>,
        device: Device,
    ) -> Self {
        Self {
            dim,
            intermediate_dim,
            module,
            weight,
            device,
        }
    }

    /// Returns the cached forward module, building it on first use.
    ///
    /// The module computes `(x · upᵀ * silu(x · gateᵀ)) · downᵀ`, with all three weight
    /// matrices converted to `BFloat16` and moved to this expert's device. Construction runs
    /// inside `tch::no_grad`.
    pub fn load_module(&self) -> &nn::Sequential {
        self.module.get_or_init(|| {
            tch::no_grad(|| {
                let target: tch::Device = self.device.into();
                // Shared preparation for every weight matrix: bf16 on the target device.
                let prepare = |raw: tch::Tensor| -> tch::Tensor {
                    raw.shallow_clone()
                        .to_kind(tch::Kind::BFloat16)
                        .to_device(target)
                };

                let up_proj = prepare(self.weight.up_w.inner());
                let down_proj = prepare(self.weight.down_w.inner());
                let gate_proj = prepare(self.weight.gate_w.inner());

                nn::seq().add_fn(move |xs| {
                    let up = xs.matmul(&up_proj.transpose(0, 1));
                    let gate = xs.matmul(&gate_proj.transpose(0, 1));
                    let hidden = up * gate.silu();
                    hidden.matmul(&down_proj.transpose(0, 1))
                })
            })
        })
    }

    /// Device this expert moves its weights to when the module is built.
    pub fn device(&self) -> Device {
        self.device
    }
}

impl Expert<TchTensor> for TorchFFN {
    /// Runs `x` through the (lazily built) forward module and wraps the result.
    fn forward(&self, x: &TchTensor) -> TchTensor {
        TchTensor(self.load_module().forward(&x.inner()))
    }

    /// Creates a random `BFloat16` input of shape `[batch, dim]`.
    ///
    /// NOTE(review): the tensor is always allocated on `Device::CPU`, independent of
    /// `self.device` — confirm callers move it when the expert lives on another device.
    fn rand_input(&self, batch: usize) -> TchTensor {
        let dims = vec![batch, self.dim];
        TchTensor::rand(dims, DType::BFloat16, Device::CPU)
    }

    /// Reports the hidden and intermediate sizes this expert was configured with.
    fn shape(&self) -> ExpertShape {
        let (hidden, intermediate) = (self.dim, self.intermediate_dim);
        ExpertShape {
            hidden,
            intermediate,
        }
    }

    /// Name of the backend implementing this expert.
    fn backend(&self) -> std::string::String {
        String::from("torch")
    }

    /// Builds an expert from an instance description and its weights.
    ///
    /// The forward module is left uninitialised and is created on first use.
    fn construct(x: crate::x::EKInstance, weight: ExpertWeight<TchTensor>) -> EKResult<Self> {
        Ok(Self::new(
            x.hidden,
            x.intermediate,
            OnceLock::new(),
            weight,
            x.device,
        ))
    }
}

#[cfg(test)]
mod test {
    use std::fs;

    use ek_base::utils::workspace_root;
    use safetensors::SafeTensors;
    use tch::IndexOp;

    use crate::{
        backend::{Device, EkTensor},
        ffn::{Expert, ExpertWeight, expert_torch::TorchFFN},
        x::{self, test_root},
    };

    use super::{TchTensor, w8a16_activate};

    /// Round-trips a random tensor through `serialize` and `SafeTensors::deserialize` and
    /// checks that the element sum survives.
    #[test]
    fn test_io() {
        let original = tch::Tensor::randn(vec![128, 128], (tch::Kind::Float, tch::Device::Cpu));
        let wrapped = TchTensor::from(original.copy());
        let encoded = wrapped.serialize();

        let archive = SafeTensors::deserialize(&encoded).unwrap();
        let view = archive.tensor("data").unwrap();
        let restored = TchTensor::from_tensor_view(&view);

        let restored_sum = restored.inner().sum(tch::Kind::Float);
        let original_sum = original.sum(tch::Kind::Float);
        assert!(restored_sum == original_sum);
    }

    /// Runs the stored `qwen3-l0e1` fixture through `TorchFFN` and prints the summed
    /// difference against the stored reference output.
    ///
    /// NOTE(review): nothing here asserts on that difference — the test only fails if loading,
    /// the forward pass, or the `f32` conversions panic. Consider adding a tolerance check once
    /// an acceptable error bound is known.
    #[test]
    fn test_correctness() {
        let weight_path = test_root()
            .join("resources")
            .join("qwen3-l0e1.weight.safetensors");
        let weight_bytes = fs::read(weight_path).unwrap();
        let weight_archive = SafeTensors::deserialize(&weight_bytes).unwrap();
        let weight = ExpertWeight::from_safetensor(&weight_archive, Device::CPU).unwrap();

        let instance = x::EKInstance {
            hidden: 2048,
            intermediate: 768,
            backend: x::ExpertBackendType::Torch,
            device: Device::CPU,
        };
        let ffn = TorchFFN::construct(instance, weight).unwrap();

        let reference_path = test_root()
            .join("resources")
            .join("qwen3-l0e1.result.safetensors");
        let reference_bytes = fs::read(reference_path).unwrap();
        let reference = SafeTensors::deserialize(&reference_bytes).unwrap();

        let input_view = reference.tensor("1-input").unwrap();
        let input = TchTensor::from_tensor_view(&input_view);

        let actual = ffn.forward(&input).inner();
        let expected_view = reference.tensor("1-output").unwrap();
        let expected = TchTensor::from_tensor_view(&expected_view).inner();

        // Converting the first 100 values of row 0 doubles as a check that both are readable.
        let _actual_head = Vec::<f32>::try_from(actual.i((0, 0..100))).unwrap();
        let _expected_head = Vec::<f32>::try_from(expected.i((0, 0..100))).unwrap();

        let delta = actual - expected;
        delta.sum(tch::Kind::BFloat16).print();
    }

    /// Checks `w8a16_activate` against the `triton_dequanted` reference stored in the fixture.
    ///
    /// NOTE(review): the differences are summed before `abs`, so positive and negative errors
    /// can cancel; the `0.2` bound is on that signed sum, not on the per-element error.
    #[test]
    fn test_fp8_dequant() {
        let fixture_path = workspace_root()
            .join("ek-computation")
            .join("resources")
            .join("w8a16active-l0q_a_proj.safetensors");
        let fixture_bytes = fs::read(fixture_path).unwrap();
        let fixture = SafeTensors::deserialize(&fixture_bytes).unwrap();

        let src_view = fixture.tensor("src").unwrap();
        let scale_view = fixture.tensor("src_scale").unwrap();
        let expected_view = fixture.tensor("triton_dequanted").unwrap();

        let src = TchTensor::from_tensor_view(&src_view).inner();
        let scale = TchTensor::from_tensor_view(&scale_view).inner();
        let expected = TchTensor::from_tensor_view(&expected_view).inner();

        let actual = w8a16_activate(&src, &scale, 128);
        let signed_error = (actual - expected).sum(tch::Kind::Double);
        let diff = signed_error.abs().double_value(&[]);
        assert!(diff < 0.2);
    }
}

#[cfg(test)]
mod bench_ffn_concurrent {
    use super::*;
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use tokio::task::JoinSet;

    /// Benchmarks randomly initialised `TorchFFN`s serially and across spawned tasks, then
    /// prints a latency comparison. Nothing is asserted; the output is for manual inspection.
    ///
    /// NOTE(review): a plain `#[tokio::test]` uses a current-thread runtime unless a flavor is
    /// given, so how much real parallelism the concurrent phase gets depends on how
    /// `run_concurrent_benchmark` spawns its work.
    #[tokio::test]
    async fn bench_torch_ffn_concurrent_vs_serial() {
        // Benchmark knobs.
        let ffn_count = 64;
        let batch_size = 1; // Rows per request
        let reqests_num = 512; // Requests issued in each benchmark phase
        let thread_num = 64; // Concurrent tasks; requests are spread evenly across them

        // Build the randomly initialised FFNs, each behind an `Arc` so tasks can share them.
        let ffns: Vec<Arc<TorchFFN>> = (0..ffn_count)
            .map(|_| {
                Arc::new(TorchFFN {
                    dim: 2048,
                    intermediate_dim: 768,
                    weight: ExpertWeight::from_rand_linear(
                        2048,
                        768,
                        crate::backend::DType::BFloat16,
                        crate::backend::Device::CPU,
                    ),
                    module: OnceLock::new(),
                    device: Device::CPU,
                })
            })
            .collect();

        println!("🔥 FFN Concurrent vs Serial Performance Comparison Test");
        println!("  FFN count: {ffn_count}");
        println!("  Batch size: {batch_size}");
        println!("  Requests num: {reqests_num}");
        println!("  Thread num: {thread_num}");
        println!();

        // Touch every FFN once so lazy module construction is not part of the measurement.
        warm_up_ffns(&ffns, batch_size).await;

        // Phase 1: one request at a time.
        let serial_results = run_serial_benchmark(&ffns, batch_size, reqests_num).await;

        // Phase 2: the same number of requests split across tasks.
        let concurrent_results =
            run_concurrent_benchmark(&ffns, batch_size, reqests_num, thread_num).await;

        // Phase 3: report.
        compare_results(&serial_results, &concurrent_results, batch_size, thread_num);
    }

    /// Runs one forward pass on every FFN so each has built its module before timing starts.
    async fn warm_up_ffns(ffns: &[Arc<TorchFFN>], batch_size: usize) {
        println!("🚀 Warming up {} FFNs...", ffns.len());
        ffns.iter().for_each(|ffn| {
            let warmup_input = ffn.rand_input(batch_size);
            // The output is irrelevant; only the side effect of the first call matters.
            drop(ffn.forward(&warmup_input));
        });
        println!("✅ Warmup completed");
        println!();
    }

    /// Issues `requests_num` forward passes one after another, each on a randomly chosen FFN,
    /// and returns the latency of every pass.
    ///
    /// A single random input (generated by the first FFN) is reused for all requests.
    async fn run_serial_benchmark(
        ffns: &[Arc<TorchFFN>],
        batch_size: usize,
        requests_num: usize,
    ) -> Vec<Duration> {
        println!("📊 Starting serial test...");
        let mut latencies = Vec::with_capacity(requests_num);
        let shared_input = ffns[0].rand_input(batch_size);

        for _ in 0..requests_num {
            // Pick the target before starting the clock so selection is not measured.
            let chosen = &ffns[rand::random_range(0..ffns.len())];

            let started = Instant::now();
            let output = chosen.forward(&shared_input);
            latencies.push(started.elapsed());

            // Keep the optimizer from discarding the forward pass.
            drop(std::hint::black_box(output));
        }

        println!("✅ Serial test completed");
        latencies
    }

    /// Splits `requests_num` forward passes across `concurrent_tasks` workers that run in
    /// parallel, and returns the latency of every pass (in no particular order).
    ///
    /// Each worker generates one random input and runs its share of requests on randomly
    /// chosen FFNs. The first `requests_num % concurrent_tasks` workers take one extra request,
    /// and every worker runs at least once, so the total can exceed `requests_num` when there
    /// are more workers than requests.
    ///
    /// The forward pass is synchronous CPU-bound work, so workers are started with
    /// `spawn_blocking`: that gives each its own OS thread regardless of the runtime flavor.
    /// Spawning them as async tasks would execute them back-to-back on a current-thread
    /// runtime (the `#[tokio::test]` default), making the "concurrent" numbers serial.
    ///
    /// Returns an empty vector when `concurrent_tasks` is zero.
    ///
    /// # Panics
    /// Panics if `ffns` is empty or if any worker panics.
    async fn run_concurrent_benchmark(
        ffns: &[Arc<TorchFFN>],
        batch_size: usize,
        requests_num: usize,
        concurrent_tasks: usize,
    ) -> Vec<Duration> {
        println!("📊 Starting concurrent test...");

        if concurrent_tasks == 0 {
            // Nothing to distribute the requests over; avoid dividing by zero below.
            return Vec::new();
        }

        // Even split, with the remainder handed out one-by-one to the first workers.
        let base_reqs_num = requests_num / concurrent_tasks;
        let extra_reqs_num = requests_num % concurrent_tasks;

        println!(
            "  Base requests per task: {base_reqs_num}, extra requests allocated to first {extra_reqs_num} tasks"
        );

        let mut join_set = JoinSet::new();

        for task_id in 0..concurrent_tasks {
            let ffns_clone = ffns.to_vec();

            // Base share, plus one for the first `extra_reqs_num` workers, and never zero.
            let reqs_num = (base_reqs_num + usize::from(task_id < extra_reqs_num)).max(1);

            join_set.spawn_blocking(move || {
                let mut task_results = Vec::with_capacity(reqs_num);
                let inp = ffns_clone[0].rand_input(batch_size);

                for _ in 0..reqs_num {
                    // Pick the target before starting the clock so selection is not measured.
                    let ffn = &ffns_clone[rand::random_range(0..ffns_clone.len())];

                    let start = Instant::now();
                    let r = ffn.forward(&inp);
                    task_results.push(start.elapsed());

                    // Keep the optimizer from discarding the forward pass.
                    drop(std::hint::black_box(r));
                }

                task_results
            });
        }

        // Wait for every worker; a worker panic is propagated here.
        let task_results = join_set.join_all().await;

        println!("✅ Concurrent test completed");
        task_results.into_iter().flatten().collect()
    }

    /// Prints average / min / max latency for both runs, the relative slowdown of the
    /// concurrent run, and a percentile breakdown of each.
    ///
    /// Prints an error line and returns early when either result set is empty.
    fn compare_results(
        serial_results: &[Duration],
        concurrent_results: &[Duration],
        batch_size: usize,
        concurrent_tasks: usize,
    ) {
        // Averages and extrema below need at least one sample per run.
        if serial_results.is_empty() {
            println!("❌ Serial test results are empty!");
            return;
        }
        if concurrent_results.is_empty() {
            println!("❌ Concurrent test results are empty!");
            return;
        }

        // (average, minimum, maximum) of a non-empty sample set.
        let summarize = |samples: &[Duration]| {
            let avg = samples.iter().sum::<Duration>() / samples.len() as u32;
            let min = *samples.iter().min().unwrap();
            let max = *samples.iter().max().unwrap();
            (avg, min, max)
        };
        let (serial_avg, serial_min, serial_max) = summarize(serial_results);
        let (concurrent_avg, concurrent_min, concurrent_max) = summarize(concurrent_results);

        let serial_avg_us = serial_avg.as_micros();
        let concurrent_avg_us = concurrent_avg.as_micros();
        let per_request = batch_size as f64;

        println!();
        println!("📈 Performance Comparison Results:");
        println!(
            "  Actual tests: {} serial rounds, {} concurrent rounds",
            serial_results.len(),
            concurrent_results.len()
        );
        println!("┌─────────────────────────────────────────┐");
        println!("│            Serial Test Results          │");
        println!("├─────────────────────────────────────────┤");
        println!("│ Avg latency: {:>8} μs                │", serial_avg_us);
        println!(
            "│ Min latency: {:>8} μs                │",
            serial_min.as_micros()
        );
        println!(
            "│ Max latency: {:>8} μs                │",
            serial_max.as_micros()
        );
        println!(
            "│ Per sequence: {:>8.1} μs               │",
            serial_avg_us as f64 / per_request
        );
        println!("└─────────────────────────────────────────┘");

        println!();
        println!("┌─────────────────────────────────────────┐");
        println!("│    Concurrent Test Results ({concurrent_tasks} tasks)   │");
        println!("├─────────────────────────────────────────┤");
        println!(
            "│ Avg latency:    {:>8} μs             │",
            concurrent_avg_us
        );
        println!(
            "│ Min latency:    {:>8} μs             │",
            concurrent_min.as_micros()
        );
        println!(
            "│ Max latency:    {:>8} μs             │",
            concurrent_max.as_micros()
        );
        println!(
            "│ Per sequence:   {:>8.1} μs             │",
            concurrent_avg_us as f64 / per_request
        );
        println!("└─────────────────────────────────────────┘");

        println!();
        // Ratio of whole-microsecond averages: > 1 means the concurrent run was slower.
        let slowdown_ratio = concurrent_avg_us as f64 / serial_avg_us as f64;
        let latency_delta_us = concurrent_avg_us as i64 - serial_avg_us as i64;
        println!("🎯 Key Metrics:");
        println!("  • Performance degradation: {slowdown_ratio:.2}x");
        println!(
            "  • Concurrent latency increase: {} μs",
            latency_delta_us
        );

        // Percentile breakdown of both runs.
        println!();
        println!("📊 Latency Distribution Analysis:");
        print_latency_distribution("Serial", serial_results);
        print_latency_distribution("Concurrent", concurrent_results);
    }

    /// Prints the P50 / P90 / P95 / P99 latencies (in whole microseconds) of `results`.
    ///
    /// The percentile at `q` is the element at index `len * q / 100` of the sorted samples.
    /// An empty `results` prints a placeholder line instead of panicking.
    fn print_latency_distribution(name: &str, results: &[Duration]) {
        println!("  {name} latency distribution:");
        if results.is_empty() {
            // There is no element to index; report the absence rather than panic.
            println!("    (no samples)");
            return;
        }

        let mut latencies: Vec<u64> = results.iter().map(|d| d.as_micros() as u64).collect();
        // Plain integers: stability is irrelevant, so the cheaper unstable sort is fine.
        latencies.sort_unstable();

        // `q < 100` keeps `len * q / 100` strictly below `len`, so the index is always valid.
        let percentile = |q: usize| latencies[latencies.len() * q / 100];
        let (p50, p90, p95, p99) = (
            percentile(50),
            percentile(90),
            percentile(95),
            percentile(99),
        );

        println!("    P50: {p50} μs, P90: {p90} μs, P95: {p95} μs, P99: {p99} μs");
    }
}