#!/usr/bin/env python3
"""
从 https://models.dev/api.json 获取 LLM 价格数据,
生成 internal/price/presets.go 文件
"""

import json
import math
import re
import urllib.request
from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path

# Source of the pricing catalogue (a JSON object keyed by provider id).
LLM_PRICE_URL = "https://models.dev/api.json"

# Provider ids to export, processed in this order by main().
PROVIDERS = [
    "openai",      # GPT series
    "anthropic",   # Claude series
    "google",      # Gemini series
    "deepseek",    # DeepSeek series
    "xai",         # Grok series
    "alibaba",     # Qwen series
    "zhipuai",     # GLM series
    "minimax",     # MiniMax series
    "moonshotai",  # Kimi/Moonshot
    "v0",          # v0 series
]

# Static alias map for non-Claude models (Claude aliases are generated
# automatically by generate_claude_aliases).
# Keys are compared against the lower-cased model id in main(), so write
# keys in lower case.
MODEL_ALIASES: dict[str, list[str]] = {
    # Add aliases for other models here
}

# Go source template rendered with str.format(): `{update_time}` and
# `{entries}` are the placeholders filled by main(); `{{` / `}}` are escaped
# literal braces of the Go map literal. The text is written verbatim into the
# generated file, so keep it byte-exact.
PRESETS_GO_TEMPLATE = '''package price

// Code generated by scripts/updatePrice.py. DO NOT EDIT.
// Last updated: {update_time}

import (
	"sync"

	"github.com/bestruirui/octopus/internal/model"
)

var llmPriceLock sync.RWMutex

var llmPrice = map[string]model.LLMPrice{{
{entries}
}}
'''


def fetch_price_data(url: str | None = None, timeout: float = 30.0) -> dict:
    """Download and decode the pricing catalogue.

    Args:
        url: Endpoint to fetch. ``None`` (the default) means ``LLM_PRICE_URL``.
        timeout: Socket timeout in seconds. Without one, ``urlopen`` can block
            forever on a stalled connection and hang the generator.

    Returns:
        The decoded top-level JSON object (``main`` indexes it by provider id).

    Raises:
        urllib.error.URLError: (including ``HTTPError``) on network/HTTP failures.
        TimeoutError: if the server stops responding within ``timeout``.
        json.JSONDecodeError: if the body is not valid JSON.
        ValueError: if the top-level JSON value is not an object.
    """
    if url is None:
        url = LLM_PRICE_URL
    req = urllib.request.Request(
        url,
        # A browser-like User-Agent; some CDNs reject urllib's default one.
        headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as response:
        data = json.loads(response.read().decode("utf-8"))
    # Fail loudly here rather than with a confusing error when main() indexes it.
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object from {url}, got {type(data).__name__}")
    return data


def format_price(value: float | None) -> str:
    """格式化价格值,去除不必要的尾部零"""
    if value is None:
        return "0"
    if value == 0:
        return "0"
    # 使用 repr 避免浮点数精度问题,然后去除尾部零
    formatted = f"{value:.10f}".rstrip("0").rstrip(".")
    return formatted


def generate_claude_aliases(model_id: str) -> list[str]:
    """
    Build alternative spellings for a Claude model id.

    Two id shapes are recognised (major and minor versions are single digits,
    so a trailing date such as ``-20250514`` is never read as a version):

    * ``claude-{type}-{major}-{minor}[-suffix]`` yields three aliases:
      ``claude-{type}-{major}.{minor}[-suffix]``,
      ``claude-{major}.{minor}-{type}[-suffix]`` and
      ``claude-{major}-{minor}-{type}[-suffix]``.
      Example: ``claude-opus-4-5`` -> ``claude-opus-4.5``, ``claude-4.5-opus``,
      ``claude-4-5-opus``.
    * ``claude-{major}-{minor}-{type}[-suffix]`` yields two aliases:
      ``claude-{major}.{minor}-{type}[-suffix]`` and
      ``claude-{type}-{major}.{minor}[-suffix]``.
      Example: ``claude-3-5-sonnet-20241022`` -> ``claude-3.5-sonnet-20241022``,
      ``claude-sonnet-3.5-20241022``.

    ``type`` is one of opus/sonnet/haiku. Any other id yields an empty list.
    """
    if not model_id.startswith("claude-"):
        return []

    # Type-first form, e.g. claude-opus-4-5, claude-sonnet-4-5-20250929.
    type_first = re.match(r"^claude-(opus|sonnet|haiku)-(\d)-(\d)(-.*)?$", model_id)
    if type_first is not None:
        model_type, major, minor, suffix = type_first.groups(default="")
        return [
            f"claude-{model_type}-{major}.{minor}{suffix}",
            f"claude-{major}.{minor}-{model_type}{suffix}",
            f"claude-{major}-{minor}-{model_type}{suffix}",
        ]

    # Version-first form, e.g. claude-3-5-sonnet-20241022, claude-3-7-sonnet-latest.
    version_first = re.match(r"^claude-(\d)-(\d)-(opus|sonnet|haiku)(-.*)?$", model_id)
    if version_first is None:
        return []
    major, minor, model_type, suffix = version_first.groups(default="")
    return [
        f"claude-{major}.{minor}-{model_type}{suffix}",
        f"claude-{model_type}-{major}.{minor}{suffix}",
    ]


def generate_entry(model_id: str, cost: dict | None) -> str:
    """Render one line of the Go ``map[string]model.LLMPrice`` literal.

    Args:
        model_id: The map key. It comes from external data, so it is encoded
            with ``json.dumps`` rather than wrapped in bare quotes. A ``"`` or
            ``\\`` in an id would otherwise produce uncompilable Go.
        cost: Mapping with optional ``input``, ``output``, ``cache_read`` and
            ``cache_write`` prices. ``None`` or missing keys render as ``0``.

    Returns:
        A tab-indented ``"<id>": {Input: ..., Output: ..., CacheRead: ...,
        CacheWrite: ...},`` line.
    """
    cost = cost or {}
    input_price = format_price(cost.get("input"))
    output_price = format_price(cost.get("output"))
    cache_read = format_price(cost.get("cache_read"))
    cache_write = format_price(cost.get("cache_write"))

    # Every escape JSON emits (\" \\ \n \t \b \f \r \u00XX) is also a valid Go
    # escape. ensure_ascii=False keeps non-ASCII as raw UTF-8, which Go source
    # accepts, and avoids \uD8xx surrogate pairs, which Go rejects.
    key = json.dumps(model_id, ensure_ascii=False)
    return f"\t{key}: {{Input: {input_price}, Output: {output_price}, CacheRead: {cache_read}, CacheWrite: {cache_write}}},"


def main():
    """Fetch prices and regenerate ``internal/price/presets.go``.

    Walks ``PROVIDERS`` in order. For each model it emits one map entry under
    the lower-cased model id, plus one per alias (auto-generated Claude aliases
    and ``MODEL_ALIASES``), all sharing the model's cost. The result is written
    into ``PRESETS_GO_TEMPLATE`` at ``<script dir>/../internal/price/presets.go``.
    """
    print(f"Fetching price data from {LLM_PRICE_URL}...")
    raw_price = fetch_price_data()

    entries: list[str] = []
    # Keys already emitted. Go rejects duplicate constant keys in a map literal,
    # so the first occurrence wins: earlier providers first, and a model's own
    # id before its aliases.
    seen: set[str] = set()
    model_count = 0

    for provider in PROVIDERS:
        if provider not in raw_price:
            print(f"  Provider '{provider}' not found, skipping...")
            continue

        # `or {}` also covers an explicit JSON null.
        models = raw_price[provider].get("models") or {}
        provider_count = 0

        for model_data in models.values():
            model_id = (model_data.get("id") or "").lower()
            if not model_id:
                continue
            cost = model_data.get("cost") or {}

            aliases = generate_claude_aliases(model_id) + list(MODEL_ALIASES.get(model_id, []))
            # dict.fromkeys de-duplicates (after lower-casing) while keeping
            # insertion order. A set here made the alias order, and so the
            # generated file, vary between runs because str hashing is randomized.
            keys = dict.fromkeys([model_id, *(alias.lower() for alias in aliases)])
            for key in keys:
                if key in seen:
                    print(f"    Duplicate key '{key}' in {provider}, skipping...")
                    continue
                seen.add(key)
                entries.append(generate_entry(key, cost))
                provider_count += 1

        print(f"  {provider}: {provider_count} models")
        model_count += provider_count

    # Render the Go file content.
    update_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    content = PRESETS_GO_TEMPLATE.format(
        update_time=update_time,
        entries="\n".join(entries),
    )

    # Write the output file.
    script_dir = Path(__file__).parent
    output_path = script_dir.parent / "internal" / "price" / "presets.go"

    # Force LF endings so the generated file is identical on every OS; text
    # mode would otherwise write CRLF on Windows. (`newline` needs Python 3.10+.)
    output_path.write_text(content, encoding="utf-8", newline="\n")
    print(f"\nGenerated {output_path} with {model_count} models")


# Regenerate presets.go only when run as a script, not when imported.
if __name__ == "__main__":
    main()