use std::sync::OnceLock;

use criterion::{BatchSize, Criterion};
use ek_computation::{
    backend::{Device, EkTensor},
    ffn::{
        expert_ggml::GgmlFFN,
        expert_torch::TorchFFN,
        meta::{Expert, ExpertWeight},
    },
};

use crate::{BACKEND2DEVICES, DEVICES};

/// Batch sizes swept by [`bench`]; each value is passed to `rand_input` to size
/// the input of one benchmark per backend/device pair.
const BATCH_SIZES: &[usize] = &[1, 4, 8, 16, 32, 64, 128, 256, 512];

/// Registers the `"ffn w/o weight transfer"` Criterion benchmark group.
///
/// One benchmark is registered for every combination of a batch size from
/// `BATCH_SIZES`, a backend key of `BACKEND2DEVICES`, and a device listed for
/// that backend. The expert weight and the FFN are built once per benchmark,
/// outside the timed routine. Each timed iteration receives a fresh random
/// input placed on the CPU, moves it to the target device, runs `forward`, and
/// moves the result back to the CPU.
///
/// # Panics
///
/// The benchmark closure panics when it runs for a backend other than
/// `"ggml"` or `"torch"`.
pub fn bench(c: &mut Criterion) {
    let mut group = c.benchmark_group("ffn w/o weight transfer");

    for &batch_size in BATCH_SIZES {
        for &backend in BACKEND2DEVICES.keys() {
            for &dev in &BACKEND2DEVICES[backend] {
                group.bench_with_input(
                    format!("batch={batch_size}, backend={backend} ({dev})"),
                    &batch_size,
                    |b, &batch_size| {
                        // Resolve the device once so the lookup is not repeated
                        // (and not measured) on every timed iteration.
                        let device = DEVICES[dev];
                        // Both backends are driven with the same randomly
                        // initialised weight configuration.
                        let weight = ExpertWeight::from_rand_linear(
                            2048,
                            768,
                            ek_computation::backend::DType::BFloat16,
                            device,
                        );

                        if backend == "ggml" {
                            // NOTE(review): the trailing `8` is presumably a
                            // thread count; confirm against `GgmlFFN::new`.
                            let ffn = GgmlFFN::new(2048, 768, weight, 8);
                            b.iter_batched(
                                // Setup (untimed): fresh random input on the CPU.
                                || ffn.rand_input(batch_size).to_device(Device::CPU),
                                // Routine (timed): host -> device, forward, device -> host.
                                |input| {
                                    let _ = std::hint::black_box(
                                        ffn.forward(&input.to_device(device))
                                            .to_device(Device::CPU),
                                    );
                                },
                                BatchSize::PerIteration,
                            );
                        } else if backend == "torch" {
                            let ffn = TorchFFN::new(2048, 768, OnceLock::new(), weight, device);
                            b.iter_batched(
                                // Setup (untimed): fresh random input on the CPU.
                                || ffn.rand_input(batch_size).to_device(Device::CPU),
                                // Routine (timed): host -> device, forward, device -> host.
                                |input| {
                                    let _ = std::hint::black_box(
                                        ffn.forward(&input.to_device(device))
                                            .to_device(Device::CPU),
                                    );
                                },
                                BatchSize::PerIteration,
                            );
                        } else {
                            // Fail loudly with the backend name instead of
                            // returning without ever driving the bencher.
                            panic!("unsupported backend `{backend}` in ffn benchmark");
                        }
                    },
                );
            }
        }
    }

    // Explicitly close the group so Criterion finalises its reporting.
    group.finish();
}