use std::{collections::HashMap, error::Error, path::PathBuf};

use clap::Parser;
use fusor_gguf::GgufValue;
use parse::parse_key_val;

mod parse;

// Subcommands of the gguf metadata tool, parsed from the process arguments by
// `Command::parse()`. Plain `//` comments are used for maintainer notes in this
// block because `///` doc comments on clap-derived items end up in the
// generated `--help` output.
#[derive(clap::Parser, Clone)]
#[clap(name = "fusor-ml", version = "0.1.0", author = "Fusor ML Team")]
enum Command {
    /// Add metadata to a gguf file
    AddMetadata {
        /// Path to the input gguf file
        #[clap(short, long)]
        input: PathBuf,
        /// Path to the output gguf file
        #[clap(short, long)]
        output: PathBuf,
        /// Metadata to add. Format: KEY=VALUE
        #[arg(short = 'D', value_parser = parse_key_val::<String>)]
        data: Vec<(String, GgufValue)>,
    },
    /// Fuse the tokenizer metadata into the gguf file
    FuseTokenizer {
        /// Path to the input gguf file
        #[clap(short, long)]
        input: PathBuf,
        /// Path to the tokenizer file
        #[clap(short, long)]
        tokenizer: PathBuf,
        /// Path to the tokenizer config file
        // The short flag is spelled out explicitly: the derived default would be
        // `-t`, which is already taken by `--tokenizer` above, and clap rejects
        // duplicate short names within one (sub)command.
        #[clap(short = 'c', long)]
        tokenizer_config: Option<PathBuf>,
        /// Path to the output gguf file
        #[clap(short, long)]
        output: PathBuf,
    },
}

/// Entry point: parse the command line and run the selected subcommand.
///
/// Failures are reported on stderr using the error's `Display` text and turn
/// into a failing exit code, rather than a panic.
fn main() -> std::process::ExitCode {
    let outcome = match Command::parse() {
        Command::AddMetadata {
            input,
            output,
            data,
        } => add_metadata(input, output, data),
        Command::FuseTokenizer {
            input,
            tokenizer,
            output,
            tokenizer_config,
        } => fuse_tokenizer(input, tokenizer, output, tokenizer_config),
    };

    match outcome {
        Ok(()) => std::process::ExitCode::SUCCESS,
        Err(err) => {
            eprintln!("error: {err}");
            std::process::ExitCode::FAILURE
        }
    }
}

fn fuse_tokenizer(
    input: PathBuf,
    tokenizer: PathBuf,
    output: PathBuf,
    tokenizer_config: Option<PathBuf>,
) -> Result<(), Box<dyn Error>> {
    let mut metadata = Vec::new();

    let tokenizer_json: serde_json::Value =
        serde_json::from_reader(std::fs::File::open(&tokenizer)?)?;
    // Store model.vocab in tokenizer.ggml.tokens
    let token_map = tokenizer_json["model"]["vocab"].as_object().unwrap();
    let added_tokens = tokenizer_json["added_tokens"].as_array().unwrap();
    let mut tokens = token_map
        .iter()
        .map(|(token, id)| (token.clone(), id.as_u64().unwrap()))
        .collect::<Vec<_>>();
    // Add added_tokens to the token map
    for token in added_tokens {
        let id = token["id"].as_u64().unwrap();
        let token_str = token["content"].as_str().unwrap();
        tokens.push((token_str.to_string(), id));
    }
    // Find the max id
    let max_id = tokens.iter().map(|(_, id)| *id).max().unwrap();
    let mut tokens_array = vec!["<conversion-error>".to_string(); (max_id + 1) as usize];
    for (token, id) in &tokens {
        tokens_array[*id as usize] = token.clone();
    }
    let tokens = GgufValue::Array(
        tokens_array
            .iter()
            .map(|s| GgufValue::String(s.clone().into_boxed_str()))
            .collect(),
    );
    metadata.push(("tokenizer.ggml.tokens".into(), tokens));

    // Try to find bos_token and eos_token in tokenizer_config
    if let Some(config) = tokenizer_config {
        let config_json: serde_json::Value =
            serde_json::from_reader(std::fs::File::open(&config)?)?;
        if let Some(bos_token) = config_json["bos_token"].as_str() {
            let id = tokens_array
                .iter()
                .position(|s| s == bos_token)
                .ok_or_else(|| format!("Invalid bos_token: {bos_token}"))?;
            metadata.push((
                "tokenizer.ggml.bos_token_id".into(),
                GgufValue::U32(id as u32),
            ));
        }
        if let Some(eos_token) = config_json["eos_token"].as_str() {
            let id = tokens_array
                .iter()
                .position(|s| s == eos_token)
                .ok_or_else(|| format!("Invalid eos_token: {eos_token}"))?;
            metadata.push((
                "tokenizer.ggml.eos_token_id".into(),
                GgufValue::U32(id as u32),
            ));
        }
    }

    add_metadata(input, output, metadata)?;

    Ok(())
}

/// Copy the gguf file at `input` to `output`, adding the entries in `data` to
/// its metadata.
///
/// The metadata and the bytes of every tensor listed in `tensor_infos` are read
/// from `input` and the input file is closed before `output` is created.
///
/// # Errors
///
/// Returns an error if `input` cannot be opened or read as gguf, if a tensor
/// cannot be read, or if `output` cannot be created, written or flushed.
fn add_metadata(
    input: PathBuf,
    output: PathBuf,
    data: Vec<(String, GgufValue)>,
) -> Result<(), Box<dyn Error>> {
    let mut reader = std::io::BufReader::new(std::fs::File::open(&input)?);
    let mut gguf = fusor_gguf::GgufMetadata::read(&mut reader)?;
    gguf.metadata
        .extend(data.into_iter().map(|(k, v)| (k.into(), v)));

    let mut tensor = HashMap::new();
    for (key, value) in gguf.tensor_infos.iter() {
        tensor.insert(
            key.clone(),
            value.read_tensor_bytes(&mut reader, gguf.tensor_data_offset)?,
        );
    }
    // Close the input before the output is created, as the original scoping did.
    drop(reader);

    let mut writer = std::io::BufWriter::new(std::fs::File::create(&output)?);
    gguf.write(
        &mut writer,
        tensor.iter().map(|(k, v)| (k.as_ref(), v.as_ref())),
    )?;
    // Dropping a `BufWriter` discards any error from its final flush, so flush
    // explicitly to report a failed write instead of leaving a truncated file
    // behind while returning `Ok`.
    std::io::Write::flush(&mut writer)?;
    Ok(())
}