use agent_contracts::TokenEstimator;
use agent_types::{ChatMessage, ContentBlock};
use serde::{Deserialize, Serialize};

use crate::{CompactError, CompactResult};

/// Tunable parameters for [`RoughTokenEstimator`].
///
/// Every text payload is estimated as its character count divided by `chars_per_token`
/// (rounded up). Fixed overheads are then added per message and per non-text content block.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct RoughTokenEstimatorConfig {
    /// Characters assumed to make up one token. Must be non-zero; `try_new` rejects zero.
    pub chars_per_token: usize,
    /// Tokens added once for every message, regardless of its content.
    pub message_overhead_tokens: usize,
    /// Tokens added for each tool-use block, on top of its tool name and input text.
    pub tool_use_overhead_tokens: usize,
    /// Tokens added for each tool-result block, on top of its tool name and output text.
    pub tool_result_overhead_tokens: usize,
    /// Tokens added for each image block, on top of its description text.
    pub image_block_overhead_tokens: usize,
    /// Tokens added for each document block, on top of its description text.
    pub document_block_overhead_tokens: usize,
}

/// A cheap, heuristic token estimator driven by character counts.
///
/// Construct it with `RoughTokenEstimator::try_new`, which validates the configuration.
/// The `config` field is private so that an unvalidated configuration cannot be installed
/// directly from outside this module.
#[derive(Clone, Debug)]
pub struct RoughTokenEstimator {
    config: RoughTokenEstimatorConfig,
}

impl RoughTokenEstimator {
    /// Builds an estimator from `config`, refusing configurations it cannot work with.
    ///
    /// # Errors
    ///
    /// Returns `CompactError::InvalidConfiguration` when `config.chars_per_token` is zero,
    /// because every text estimate divides the character count by that value.
    pub fn try_new(config: RoughTokenEstimatorConfig) -> CompactResult<Self> {
        match config.chars_per_token {
            0 => Err(CompactError::InvalidConfiguration {
                message: String::from("chars_per_token must be greater than zero"),
            }),
            _ => Ok(Self { config }),
        }
    }

    /// Borrows the configuration this estimator was built with.
    pub fn config(&self) -> &RoughTokenEstimatorConfig {
        &self.config
    }
}

impl TokenEstimator for RoughTokenEstimator {
    /// Estimates the tokens for one message: a fixed per-message overhead plus a cost per block.
    ///
    /// Per-block cost:
    /// - `Text`: the estimate for the text itself.
    /// - `ToolUse`: `tool_use_overhead_tokens` + tool name + the `to_string()` rendering of `input`.
    /// - `ToolResult`: `tool_result_overhead_tokens` + tool name + `output`.
    /// - `Image`: `image_block_overhead_tokens` + `description`.
    /// - `Document`: `document_block_overhead_tokens` + `description`.
    ///
    /// All additions saturate at `usize::MAX` instead of overflowing. The config derives
    /// `Deserialize`, so overhead values may come from external configuration. With plain `+`,
    /// a very large value would panic in debug builds and silently wrap in release builds.
    fn estimate_message_tokens(&self, message: &ChatMessage) -> usize {
        let config = &self.config;
        message
            .blocks
            .iter()
            .map(|block| match block {
                ContentBlock::Text { text } => self.estimate_text_tokens(text),
                ContentBlock::ToolUse {
                    tool_name, input, ..
                } => config
                    .tool_use_overhead_tokens
                    .saturating_add(self.estimate_text_tokens(tool_name))
                    .saturating_add(self.estimate_text_tokens(&input.to_string())),
                ContentBlock::ToolResult {
                    tool_name, output, ..
                } => config
                    .tool_result_overhead_tokens
                    .saturating_add(self.estimate_text_tokens(tool_name))
                    .saturating_add(self.estimate_text_tokens(output)),
                ContentBlock::Image { description } => config
                    .image_block_overhead_tokens
                    .saturating_add(self.estimate_text_tokens(description)),
                ContentBlock::Document { description } => config
                    .document_block_overhead_tokens
                    .saturating_add(self.estimate_text_tokens(description)),
            })
            // Seed with the per-message overhead so it is counted exactly once.
            .fold(config.message_overhead_tokens, usize::saturating_add)
    }

    /// Sums the per-message estimates for `messages`, saturating at `usize::MAX`.
    ///
    /// An empty slice yields zero.
    fn estimate_messages_tokens(&self, messages: &[ChatMessage]) -> usize {
        messages
            .iter()
            .map(|message| self.estimate_message_tokens(message))
            .fold(0, usize::saturating_add)
    }

    /// Estimates tokens for `text` by dividing its `char` (Unicode scalar value) count by
    /// `chars_per_token` and rounding up. Empty text costs zero tokens.
    fn estimate_text_tokens(&self, text: &str) -> usize {
        let characters = text.chars().count();
        if characters == 0 {
            return 0;
        }

        // `chars_per_token` is non-zero: `try_new` rejects zero and `config` is a private field.
        characters.div_ceil(self.config.chars_per_token)
    }
}