use std::sync::OnceLock;

use agent_types::{ChatMessage, ContentBlock};

/// Heuristic, tokenizer-free estimator of how many tokens a chat request will consume.
///
/// Estimates are derived from character counts plus fixed per-item overheads. The
/// system-prompt and tool-definition estimates are computed once, on first use, and then
/// reused for the lifetime of the estimator.
#[derive(Debug)]
pub struct TokenEstimator {
    /// Memoised system-prompt estimate; set by the first `estimate_system_prompt` call.
    system_prompt_tokens: OnceLock<usize>,
    /// Memoised tool-definition estimate; set by the first `estimate_tools` call.
    tools_tokens: OnceLock<usize>,
}

impl TokenEstimator {
    /// Fixed cost added to every message on top of its content estimate.
    const MESSAGE_OVERHEAD: usize = 10;
    /// Fixed cost added to every quick text estimate.
    const QUICK_TEXT_OVERHEAD: usize = 10;
    /// Fixed cost added to the precise (system-prompt) text estimate.
    const PRECISE_TEXT_OVERHEAD: usize = 20;
    /// Extra cost of a tool-use block beyond its serialized input.
    const TOOL_USE_OVERHEAD: usize = 20;
    /// Extra cost of a tool-result block beyond its output text.
    const TOOL_RESULT_OVERHEAD: usize = 15;
    /// Extra cost of an image block beyond its description text.
    const IMAGE_OVERHEAD: usize = 100;
    /// Extra cost of a document block beyond its description text.
    const DOCUMENT_OVERHEAD: usize = 150;
    /// Assumed cost of a single tool definition.
    const TOKENS_PER_TOOL: usize = 150;
    /// Fixed cost of the tool-definition section, independent of the tool count.
    const TOOLS_BASE_OVERHEAD: usize = 200;

    /// Creates an estimator whose system-prompt and tool caches are both empty.
    pub fn new() -> Self {
        Self {
            system_prompt_tokens: OnceLock::new(),
            tools_tokens: OnceLock::new(),
        }
    }

    /// Estimates the total input tokens of a request: system prompt + tool definitions +
    /// conversation history.
    ///
    /// The system-prompt and tool components are memoised on first use; see
    /// [`Self::estimate_system_prompt`] and [`Self::estimate_tools`].
    pub fn estimate_input_tokens(
        &self,
        system_prompt: &str,
        tools_count: usize,
        messages: &[ChatMessage],
    ) -> usize {
        let system_tokens = self.estimate_system_prompt(system_prompt);
        let tools_tokens = self.estimate_tools(tools_count);
        let history_tokens = self.estimate_history_messages(messages);

        system_tokens + tools_tokens + history_tokens
    }

    /// Estimates the system prompt's token count.
    ///
    /// The value is computed from the `prompt` passed on the **first** call and cached for the
    /// lifetime of this estimator. Later calls return the cached value and ignore their `prompt`
    /// argument, so use a fresh estimator if the system prompt can change.
    pub fn estimate_system_prompt(&self, prompt: &str) -> usize {
        *self
            .system_prompt_tokens
            .get_or_init(|| self.estimate_text_precise(prompt))
    }

    /// Estimates the token cost of `count` tool definitions: a flat per-tool cost plus a base.
    ///
    /// Like the system prompt, this is cached on the first call; later calls ignore `count`.
    pub fn estimate_tools(&self, count: usize) -> usize {
        *self
            .tools_tokens
            .get_or_init(|| count * Self::TOKENS_PER_TOOL + Self::TOOLS_BASE_OVERHEAD)
    }

    /// Sums [`Self::estimate_message`] over every message in `messages`.
    pub fn estimate_history_messages(&self, messages: &[ChatMessage]) -> usize {
        messages.iter().map(|msg| self.estimate_message(msg)).sum()
    }

    /// Estimates one message: its content blocks, its reasoning content (if any), and a fixed
    /// per-message overhead.
    ///
    /// If `msg.estimated_tokens` is already set, that value is returned without recomputation.
    pub fn estimate_message(&self, msg: &ChatMessage) -> usize {
        if let Some(cached) = msg.estimated_tokens {
            return cached;
        }

        let blocks_tokens: usize = msg
            .blocks
            .iter()
            .map(|block| self.estimate_block(block))
            .sum();
        let reasoning_tokens = msg
            .reasoning_content
            .as_ref()
            .map_or(0, |reasoning| self.estimate_text_quick(reasoning));

        blocks_tokens + reasoning_tokens + Self::MESSAGE_OVERHEAD
    }

    /// Estimates a single content block; non-text blocks add a fixed per-kind cost.
    fn estimate_block(&self, block: &ContentBlock) -> usize {
        match block {
            ContentBlock::Text { text } => self.estimate_text_quick(text),
            ContentBlock::ToolUse { input, .. } => {
                // Estimate on the compact JSON serialization of the input. If serialization
                // fails, the input contributes only the quick-estimate overhead.
                let json_str = serde_json::to_string(input).unwrap_or_default();
                self.estimate_text_quick(&json_str) + Self::TOOL_USE_OVERHEAD
            }
            ContentBlock::ToolResult { output, .. } => {
                self.estimate_text_quick(output) + Self::TOOL_RESULT_OVERHEAD
            }
            ContentBlock::Image { description } => {
                self.estimate_text_quick(description) + Self::IMAGE_OVERHEAD
            }
            ContentBlock::Document { description } => {
                self.estimate_text_quick(description) + Self::DOCUMENT_OVERHEAD
            }
        }
    }

    /// Character-based estimate used for the system prompt.
    ///
    /// A single chars-per-token ratio is chosen for the whole text from its average UTF-8
    /// bytes per character. Above 2.5 (e.g. mostly CJK, which is 3 bytes per char) it assumes
    /// ~2 chars per token; otherwise it assumes ~4. The result is never less than
    /// `1 + PRECISE_TEXT_OVERHEAD`.
    fn estimate_text_precise(&self, text: &str) -> usize {
        let char_count = text.chars().count();
        // Empty text has no bytes-per-char ratio. The branch taken is irrelevant there,
        // because the `.max(1)` floor below applies either way.
        let dense = char_count == 0 || text.len() as f64 / char_count as f64 > 2.5;
        let chars_per_token = if dense { 2 } else { 4 };

        (char_count / chars_per_token).max(1) + Self::PRECISE_TEXT_OVERHEAD
    }

    /// Fast character-based estimate used for message content.
    ///
    /// ASCII characters are counted at ~4 per token and non-ASCII characters at ~2 per token.
    /// A stray accented letter, curly quote or emoji in otherwise-ASCII text therefore adds only
    /// a fraction of a token, instead of switching the whole text to the denser ratio.
    /// Pure-ASCII text yields `n / 4` and pure-non-ASCII text yields `n / 2` (plus overhead).
    fn estimate_text_quick(&self, text: &str) -> usize {
        // Accumulate in quarter-tokens so the integer division happens once, at the end.
        let quarter_tokens: usize = text
            .chars()
            .map(|c| if c.is_ascii() { 1 } else { 2 })
            .sum();

        quarter_tokens / 4 + Self::QUICK_TEXT_OVERHEAD
    }

    /// Stores the message's estimate in `msg.estimated_tokens` if it is not already set.
    ///
    /// An existing cached value is never refreshed, so call this only once the message content
    /// is final.
    pub fn update_message_cache(&self, msg: &mut ChatMessage) {
        if msg.estimated_tokens.is_none() {
            msg.estimated_tokens = Some(self.estimate_message(msg));
        }
    }
}

impl Default for TokenEstimator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use agent_llm::ChatMessageExt;
    use agent_types::MessageRole;

    #[test]
    fn estimate_text_quick_handles_chinese() {
        let estimator = TokenEstimator::new();
        let chinese_text = "这是一个中文测试文本";
        let tokens = estimator.estimate_text_quick(chinese_text);

        assert!(tokens > 0);
        assert!(tokens < chinese_text.len());
    }

    #[test]
    fn estimate_text_quick_handles_english() {
        let estimator = TokenEstimator::new();
        let english_text = "This is an English test text with multiple words";
        let tokens = estimator.estimate_text_quick(english_text);

        assert!(tokens > 0);
        assert!(tokens < english_text.len());
    }

    #[test]
    fn estimate_text_quick_empty_is_overhead_only() {
        let estimator = TokenEstimator::new();

        assert_eq!(estimator.estimate_text_quick(""), 10);
    }

    #[test]
    fn estimate_message_without_cache_computes_estimate() {
        let estimator = TokenEstimator::new();
        let msg = ChatMessage::text(MessageRole::User, "Test message", 0);

        assert_eq!(msg.estimated_tokens, None);

        let estimated = estimator.estimate_message(&msg);
        assert!(estimated > 0);
    }

    #[test]
    fn estimate_message_returns_cached_tokens() {
        let estimator = TokenEstimator::new();
        let mut msg = ChatMessage::text(MessageRole::User, "Test message", 0);
        msg.estimated_tokens = Some(42);

        assert_eq!(estimator.estimate_message(&msg), 42);
    }

    #[test]
    fn update_message_cache_stores_estimate_once() {
        let estimator = TokenEstimator::new();
        let mut msg = ChatMessage::text(MessageRole::User, "Test message", 0);
        let expected = estimator.estimate_message(&msg);

        estimator.update_message_cache(&mut msg);
        assert_eq!(msg.estimated_tokens, Some(expected));

        // An already-cached value must not be overwritten.
        msg.estimated_tokens = Some(1);
        estimator.update_message_cache(&mut msg);
        assert_eq!(msg.estimated_tokens, Some(1));
    }

    #[test]
    fn estimate_input_tokens_sums_components() {
        let estimator = TokenEstimator::new();
        let messages = vec![
            ChatMessage::text(MessageRole::User, "Hello", 0),
            ChatMessage::text(MessageRole::Assistant, "Hi there", 0),
        ];

        let total = estimator.estimate_input_tokens("System prompt", 5, &messages);
        let expected = estimator.estimate_system_prompt("System prompt")
            + estimator.estimate_tools(5)
            + estimator.estimate_history_messages(&messages);

        assert!(total > 0);
        assert_eq!(total, expected);
    }

    #[test]
    fn system_prompt_is_cached() {
        let estimator = TokenEstimator::new();

        let first = estimator.estimate_system_prompt("Test prompt");
        let second = estimator.estimate_system_prompt("Different prompt");

        assert_eq!(first, second);
    }

    #[test]
    fn tools_tokens_is_cached() {
        let estimator = TokenEstimator::new();

        let first = estimator.estimate_tools(5);
        let second = estimator.estimate_tools(10);

        assert_eq!(first, second);
    }
}