#!/usr/bin/env python3
"""Prefetch model config files required by tests into a target cache directory."""

from __future__ import annotations

import argparse
import ast
import contextlib
import json
import logging
import os
import re
import sys
from collections.abc import Iterator, Sequence
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Protocol

# Directory that contains this script (symlinks resolved).
SCRIPT_DIR = Path(__file__).resolve().parent
# Parent of the script directory; the rest of the file treats it as the repo root.
REPO_ROOT = SCRIPT_DIR.parent
# Default value for --scan-dir: the tree searched for model ids.
DEFAULT_SCAN_DIR = REPO_ROOT / "tests"
# Default value for --dest-dir: the cache directory the prefetched configs go to.
DEFAULT_DEST_DIR = REPO_ROOT / "tests" / "assets" / "cache"

# Glob patterns handed to modelscope's snapshot_download as its ignore list
# (see ModelScopePrefetcher._build_snapshot_kwargs). They name weight,
# checkpoint and archive file types, so a snapshot skips those files.
_MODELSCOPE_WEIGHT_IGNORE_PATTERNS = [
    "*.safetensors",
    "*.safetensors.index.json",
    "*.bin",
    "*.pt",
    "*.pth",
    "*.ckpt",
    "*.h5",
    "*.npz",
    "*.onnx",
    "*.gguf",
    "*.zip",
    "*.tar",
    "*.tar.gz",
]

# String prefixes that disqualify a candidate: repo-local directories,
# relative paths and URLs all contain "/" but are not model ids.
_IGNORE_PREFIXES = (
    "tests/",
    "tensor_cast/",
    "serving_cast/",
    "trace/",
    "docs/",
    "./",
    "../",
    "http://",
    "https://",
)
# "owner" segments (the text before "/") that are rejected outright.
_IGNORE_OWNERS = frozenset({"tests", "tensor_cast", "serving_cast", "trace", "docs", "web_ui"})
# Shape of an accepted id: exactly one "/", each side starting with an
# alphanumeric character followed by alphanumerics, "_", "." or "-".
_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*/[A-Za-z0-9][A-Za-z0-9_.-]*$")
# File extensions that mark a candidate as a file path rather than a model id.
_KNOWN_EXTENSIONS = frozenset({".yaml", ".yml", ".json", ".py", ".md", ".txt", ".csv"})


@dataclass(frozen=True, slots=True)
class PrefetchResult:
    model_id: str
    source: str
    success: bool
    error: str = ""

    def to_dict(self) -> dict[str, str | bool]:
        return asdict(self)


@dataclass(frozen=True, slots=True)
class EnvOverrides:
    """Environment variables to set during prefetch operations."""

    hf_home: str
    torch_home: str
    modelscope_cache: str

    @contextlib.contextmanager
    def activate(self) -> Iterator[None]:
        env_keys = (
            "HF_HOME",
            "TORCH_HOME",
            "MODELSCOPE_CACHE",
            "MSMODELING_OFFLINE",
            "HF_HUB_OFFLINE",
            "TRANSFORMERS_OFFLINE",
            "HF_DATASETS_OFFLINE",
        )
        old = {k: os.environ.get(k) for k in env_keys}
        os.environ["HF_HOME"] = self.hf_home
        os.environ["TORCH_HOME"] = self.torch_home
        os.environ["MODELSCOPE_CACHE"] = self.modelscope_cache
        os.environ["MSMODELING_OFFLINE"] = "0"
        os.environ["HF_HUB_OFFLINE"] = "0"
        os.environ["TRANSFORMERS_OFFLINE"] = "0"
        os.environ["HF_DATASETS_OFFLINE"] = "0"
        try:
            yield
        finally:
            for k, v in old.items():
                if v is None:
                    os.environ.pop(k, None)
                else:
                    os.environ[k] = v


class ConfigPrefetcher(Protocol):
    """Structural interface for a backend that can obtain one model's config.

    Any object exposing a compatible ``fetch`` method satisfies it; no
    inheritance is required.
    """

    def fetch(self, model_id: str) -> PrefetchResult:
        """Obtain the config for *model_id* and describe the outcome."""


class HuggingFacePrefetcher:
    """Prefetcher that loads configs through ``transformers.AutoConfig``."""

    def __init__(self) -> None:
        """Bind ``AutoConfig``; raises ``ImportError`` if transformers is missing."""
        # Deferred import: the module stays importable without transformers.
        from transformers import AutoConfig

        self._AutoConfig = AutoConfig

    def fetch(self, model_id: str) -> PrefetchResult:
        """Load the config for *model_id* and return a successful result.

        A first attempt is made with default arguments. If it fails with an
        exception whose message mentions ``trust_remote_code``, the load is
        retried once with ``trust_remote_code=True``. Every other exception,
        and any exception from the retry, propagates to the caller.
        """
        try:
            self._AutoConfig.from_pretrained(model_id)
        except Exception as err:
            if "trust_remote_code" in str(err):
                # NOTE(review): trust_remote_code=True allows code shipped in
                # the model repo to run locally -- confirm that is acceptable
                # for every id this scan can discover.
                self._AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            else:
                raise
        return PrefetchResult(model_id=model_id, source="huggingface", success=True)


class ModelScopePrefetcher:
    """Prefetcher that snapshots a repo via modelscope and loads its config."""

    def __init__(self) -> None:
        """Bind the modelscope entry points; raises ``ImportError`` if it is missing."""
        # Deferred import: the module stays importable without modelscope.
        import modelscope

        self._AutoConfig = modelscope.AutoConfig
        self._snapshot_download = modelscope.snapshot_download

    def fetch(self, model_id: str) -> PrefetchResult:
        """Snapshot *model_id* (minus ignored files) and load its config.

        The config is loaded from the directory returned by
        ``snapshot_download``. If loading fails with an exception whose
        message mentions ``trust_remote_code``, it is retried once with
        ``trust_remote_code=True``; every other exception propagates.
        """
        snapshot_dir = self._snapshot_download(
            model_id, **self._build_snapshot_kwargs(model_id)
        )
        try:
            self._AutoConfig.from_pretrained(snapshot_dir)
        except Exception as err:
            if "trust_remote_code" in str(err):
                # NOTE(review): trust_remote_code=True allows code shipped in
                # the model repo to run locally -- confirm that is acceptable.
                self._AutoConfig.from_pretrained(snapshot_dir, trust_remote_code=True)
            else:
                raise
        return PrefetchResult(model_id=model_id, source="modelscope", success=True)

    def _build_snapshot_kwargs(self, model_id: str) -> dict[str, Any]:
        """Return the ignore-pattern kwarg understood by ``snapshot_download``.

        The keyword is ``ignore_file_pattern`` when the bound function's
        signature has a parameter of that name, ``ignore_patterns`` otherwise.
        *model_id* is accepted but not consulted.
        """
        import inspect

        accepted = inspect.signature(self._snapshot_download).parameters  # pylint: disable=no-member
        keyword = "ignore_file_pattern" if "ignore_file_pattern" in accepted else "ignore_patterns"
        return {keyword: _MODELSCOPE_WEIGHT_IGNORE_PATTERNS}


def _looks_like_model_id(value: str) -> bool:
    """Return True when *value* has the shape of an ``owner/name`` model id.

    The stripped text must match ``_MODEL_ID_PATTERN``, must not start with
    one of ``_IGNORE_PREFIXES``, and both halves must be at least two
    characters long and contain a letter. Candidates whose owner is in
    ``_IGNORE_OWNERS`` or whose name looks like a file name are rejected.
    """
    candidate = value.strip()
    # An empty string fails the "/" test, so no separate emptiness check is needed.
    if "/" not in candidate or "\\" in candidate or " " in candidate:
        return False
    if _MODEL_ID_PATTERN.fullmatch(candidate) is None:
        return False
    if candidate.startswith(_IGNORE_PREFIXES):
        return False

    # The pattern guarantees exactly one "/", so this yields two parts.
    owner, name = candidate.split("/", 1)
    if min(len(owner), len(name)) < 2 or owner in _IGNORE_OWNERS:
        return False
    if not (any(map(str.isalpha, owner)) and any(map(str.isalpha, name))):
        return False
    if name.endswith(tuple(_KNOWN_EXTENSIONS)):
        return False

    # A short, purely alphabetic tail after the last dot reads like a file
    # extension ("foo.cfg"); tails such as "5-7B" do not and are accepted.
    _, dot, tail = name.rpartition(".")
    if dot and tail.isalpha() and len(tail) <= 5:
        return False
    return True


def _iter_string_values(data: Any) -> Iterator[str]:
    if isinstance(data, str):
        yield data
    elif isinstance(data, dict):
        for value in data.values():
            yield from _iter_string_values(value)
    elif isinstance(data, list):
        for item in data:
            yield from _iter_string_values(item)


def _collect_from_python(path: Path) -> set[str]:
    try:
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
    except (OSError, SyntaxError):
        return set()
    found: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Constant) and isinstance(node.value, str):
            candidate = node.value.strip()
            if _looks_like_model_id(candidate):
                found.add(candidate)
    return found


def _collect_from_json(path: Path) -> set[str]:
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return set()
    found: set[str] = set()
    for value in _iter_string_values(data):
        if _looks_like_model_id(value):
            found.add(value.strip())
    return found


def collect_model_ids(scan_dir: Path) -> list[str]:
    """Walk *scan_dir* recursively and return the sorted, de-duplicated model ids.

    Only ``.py`` and ``.json`` files are inspected. Files whose path
    (relative to ``REPO_ROOT`` when possible, otherwise to *scan_dir*)
    starts with one of the excluded prefixes are skipped.
    """
    excluded = ("tests/.ci/", "tests/assets/cache/", "scripts/helpers/")
    discovered: set[str] = set()
    for entry in scan_dir.rglob("*"):
        if not entry.is_file():
            continue
        try:
            relative = entry.relative_to(REPO_ROOT)
        except ValueError:
            # Not inside the repo: fall back to a path relative to the scan root.
            try:
                relative = entry.relative_to(scan_dir)
            except ValueError:
                continue
        if relative.as_posix().startswith(excluded):
            continue
        extension = entry.suffix
        if extension == ".py":
            discovered |= _collect_from_python(entry)
        elif extension == ".json":
            discovered |= _collect_from_json(entry)
    return sorted(discovered)


def _build_prefetchers() -> list[ConfigPrefetcher]:
    """Instantiate every prefetcher whose backing library can be imported.

    Backends are tried in priority order: Hugging Face first, then
    ModelScope. A backend whose constructor raises ``ImportError`` is
    skipped, and the reason is logged so a missing or broken installation
    is visible instead of being silently ignored.

    Returns:
        The usable prefetchers, possibly an empty list.
    """
    log = logging.getLogger(__name__)
    available: list[ConfigPrefetcher] = []
    for backend in (HuggingFacePrefetcher, ModelScopePrefetcher):
        try:
            available.append(backend())
        except ImportError as exc:
            log.info("%s unavailable, skipping: %s", backend.__name__, exc)
    return available


def _prefetch_all(
    model_ids: Sequence[str],
    prefetchers: Sequence[ConfigPrefetcher],
) -> list[PrefetchResult]:
    results: list[PrefetchResult] = []
    for model_id in model_ids:
        result = _try_prefetch(model_id, prefetchers)
        results.append(result)
    return results


def _try_prefetch(
    model_id: str,
    prefetchers: Sequence[ConfigPrefetcher],
) -> PrefetchResult:
    last_error = ""
    for prefetcher in prefetchers:
        try:
            return prefetcher.fetch(model_id)
        except Exception as exc:
            last_error = str(exc)
    return PrefetchResult(
        model_id=model_id,
        source="unresolved",
        success=False,
        error=last_error,
    )


def _write_manifest(dest_dir: Path, scan_dir: Path, results: list[PrefetchResult]) -> Path:
    manifest = {
        "schema_version": 1,
        "scan_dir": str(scan_dir),
        "dest_dir": str(dest_dir),
        "models": [r.to_dict() for r in results],
    }
    manifest_path = dest_dir / "model_config_manifest.json"
    manifest_path.write_text(
        json.dumps(manifest, indent=2, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )
    return manifest_path


def main(argv: Sequence[str] | None = None) -> int:
    """Discover model ids, prefetch their configs and write a manifest.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            ``main()`` callers behave as before; passing a list lets other
            code and tests drive the CLI directly.

    Returns:
        Process exit code: ``2`` when the scan directory does not exist;
        ``1`` when no model id is found, no backend is installed, or at
        least one prefetch failed; ``0`` otherwise.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s %(message)s",
        stream=sys.stderr,
    )
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--dest-dir",
        default=str(DEFAULT_DEST_DIR),
        help="Directory used as HF_HOME/TORCH_HOME/MODELSCOPE_CACHE for config prefetch.",
    )
    parser.add_argument(
        "--scan-dir",
        default=str(DEFAULT_SCAN_DIR),
        help="Directory to scan for model ids. Default: tests/",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only discover model ids and write manifest without downloading configs.",
    )
    args = parser.parse_args(argv)

    dest_dir = Path(args.dest_dir).expanduser().resolve()
    scan_dir = Path(args.scan_dir).expanduser().resolve()
    if not scan_dir.exists():
        logger.error("scan dir not found: %s", scan_dir)
        return 2

    dest_dir.mkdir(parents=True, exist_ok=True)
    # All three caches share the destination directory.
    env_overrides = EnvOverrides(
        hf_home=str(dest_dir),
        torch_home=str(dest_dir),
        modelscope_cache=str(dest_dir),
    )

    # Early returns inside the block still restore the environment, because
    # the context manager undoes its changes in a finally-clause.
    with env_overrides.activate():
        model_ids = collect_model_ids(scan_dir)
        if not model_ids:
            logger.error("No model id discovered from tests scan.")
            return 1

        logger.info("Discovered %d model ids.", len(model_ids))
        if args.dry_run:
            # Dry run: report every discovered id as a success without downloading.
            results = [PrefetchResult(model_id=mid, source="dry-run", success=True) for mid in model_ids]
            for r in results:
                logger.info("[DRY-RUN] %s", r.model_id)
        else:
            prefetchers = _build_prefetchers()
            if not prefetchers:
                logger.error("Neither transformers nor modelscope is installed.")
                return 1
            results = _prefetch_all(model_ids, prefetchers)
            for r in results:
                if r.success:
                    logger.info("[OK] %s (%s)", r.model_id, r.source)
                else:
                    logger.error("[FAIL] %s: %s", r.model_id, r.error)

    manifest_path = _write_manifest(dest_dir, scan_dir, results)
    logger.info("manifest written: %s", manifest_path)

    # The manifest is written even when some models failed; the exit code
    # reports the failure.
    failures = [item for item in results if not item.success]
    return 1 if failures else 0


# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())