use std::io::Write;

use clap::Subcommand;
use ek_base::error::EKResult;
use ek_computation::backend::DType;
use ort::session::Session;
use prost::Message;

/// ONNX-related subcommands.
#[derive(Subcommand, Debug)]
pub enum OnnxCommand {
    /// Build an expert ONNX model, write it to disk, and check that ONNX Runtime can load it.
    Export {
        /// Path of the ONNX file to write.
        #[arg(long, default_value = "onnx_model.onnx")]
        output: String,
        /// Hidden size passed to the expert ONNX builder (must be positive).
        #[arg(
            long,
            default_value_t = 7168,
            value_parser = clap::value_parser!(i64).range(1..)
        )]
        hidden_size: i64,
        /// Intermediate size passed to the expert ONNX builder (must be positive).
        #[arg(
            long,
            default_value_t = 2048,
            value_parser = clap::value_parser!(i64).range(1..)
        )]
        inter_size: i64,
    },
}

/// Runs an ONNX subcommand.
///
/// For [`OnnxCommand::Export`], builds an expert model with `ExpertOnnxBuilder`
/// (using `DType::Float`), serializes it with prost, writes the bytes to `output`,
/// and then loads the same bytes into an ONNX Runtime [`Session`] as a check
/// that the exported model is accepted.
///
/// The file is written before the load check. A model that ONNX Runtime
/// rejects is therefore still left on disk for inspection.
///
/// # Errors
///
/// Currently never returns `Err`; every failure panics (see below).
///
/// # Panics
///
/// Panics if `output` cannot be created or written. It also panics if ONNX
/// Runtime fails to create a session builder or to load the serialized model.
/// Each panic message includes the underlying error.
pub async fn execute_onnx(cmd: OnnxCommand) -> EKResult<()> {
    match cmd {
        OnnxCommand::Export {
            output,
            hidden_size,
            inter_size,
        } => {
            let builder = ek_computation::onnx::exporter::ExpertOnnxBuilder {
                intermediate_size: inter_size,
                hidden_size,
                data_type: DType::Float,
            };
            let raw = builder.build().encode_to_vec();

            // These failures come from user input (bad path, permissions, full
            // disk). Report which file failed and why instead of a bare unwrap.
            let mut file = std::fs::File::create(&output)
                .unwrap_or_else(|e| panic!("failed to create {output}: {e}"));
            file.write_all(&raw)
                .unwrap_or_else(|e| panic!("failed to write ONNX model to {output}: {e}"));

            // Round-trip the serialized bytes through ONNX Runtime to confirm the
            // exported model is loadable; the session itself is not used further.
            let _session = Session::builder()
                .unwrap_or_else(|e| panic!("failed to create ONNX Runtime session builder: {e}"))
                .commit_from_memory(raw.as_slice())
                .unwrap_or_else(|e| panic!("ONNX Runtime rejected exported model {output}: {e}"));
            Ok(())
        }
    }
}