"""Centralised env-var configuration for all helpers.

Loaded via pydantic-settings from process environment. Shell entry scripts
(run_ci_gate.sh, run_nightly.sh, etc.) set env-vars with defaults.
"""

from __future__ import annotations

from typing import Final

from pydantic import Field, ValidationError, ValidationInfo, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

# Inclusive upper bound applied to line_threshold / branch_threshold by
# Config._validate_threshold (the lower bound there is 0).
_THRESHOLD_MAX: Final = 100.0
# Accepted spellings for boolean values; _parse_bool_env strips and
# lower-cases the input before looking it up in these sets.
_BOOL_TRUE = frozenset({"1", "true", "yes", "on"})
_BOOL_FALSE = frozenset({"0", "false", "no", "off"})


class ConfigError(Exception):
    """Configuration could not be loaded because a value is missing or invalid.

    ``Config.from_env`` raises this in place of pydantic's ``ValidationError``;
    the message holds one line per offending setting.
    """


# Maps each Config field name to the environment variable it is read from.
# Used to put the env-var name (rather than the Python attribute name) into
# error messages. The values mirror the ``validation_alias`` of the matching
# field on Config, so the two must be kept in sync by hand.
_FIELD_ENV_KEYS: Final = {
    "test_map_path": "MSMODELING_TEST_MAP_PATH",
    "base_branch": "MSMODELING_TEST_BASE_BRANCH",
    "line_threshold": "MSMODELING_TEST_LINE_THRESHOLD",
    "branch_threshold": "MSMODELING_TEST_BRANCH_THRESHOLD",
    "benchmark_parallel": "MSMODELING_BENCHMARK_PARALLEL",
    "feishu_webhook_url": "FEISHU_WEBHOOK_URL",
    "msmodeling_cache": "MSMODELING_CACHE",
    "weights_prune": "MSMODELING_TEST_WEIGHTS_PRUNE",
    "gitcode_owner": "GITCODE_OWNER",
    "gitcode_repo": "GITCODE_REPO",
    "gitcode_pr_number": "GITCODE_PR_NUMBER",
    "gitcode_pat": "GITCODE_PAT",
}


def format_expected_got(field: str, expected: str, got: object) -> str:
    """Build the standard "expected X, got Y" validation message.

    Args:
        field: Name of the setting being validated; rendered with ``repr``.
        expected: Human-readable description of the accepted value,
            inserted verbatim (e.g. ``"a number"``).
        got: The offending value; rendered with ``repr``.

    Returns:
        A two-sentence message describing the mismatch.
    """
    expectation = f"Expected {field!r} to be {expected}."
    observation = f"Got {got!r} instead."
    return f"{expectation} {observation}"


def _format_validation_error(exc: ValidationError) -> str:
    """Render a validation error as one ``<key>: <message>`` line per problem.

    For every entry of ``exc.errors()`` the last element of its ``loc`` is
    looked up in ``_FIELD_ENV_KEYS``; when it is not a known field name it is
    used as-is, and ``"config"`` is used when ``loc`` is empty. A leading
    ``"Value error, "`` is removed from string messages.

    Args:
        exc: The validation error to summarise.

    Returns:
        The rendered lines joined with newlines (empty string if no errors).
    """
    noise_prefix = "Value error, "

    def _render(detail) -> str:
        location = detail.get("loc", ())
        name = str(location[-1]) if location else "config"
        message = detail.get("msg", "invalid value")
        if isinstance(message, str):
            # removeprefix is a no-op when the prefix is absent.
            message = message.removeprefix(noise_prefix)
        return f"{_FIELD_ENV_KEYS.get(name, name)}: {message}"

    return "\n".join(_render(detail) for detail in exc.errors())


def _parse_bool_env(value: object, *, default: bool, field: str) -> bool:
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        raise ValueError(format_expected_got(field, "a boolean", value))
    raw = value.strip().lower()
    if not raw:
        raise ValueError(format_expected_got(field, "a boolean", value))
    if raw in _BOOL_TRUE:
        return True
    if raw in _BOOL_FALSE:
        return False
    raise ValueError(format_expected_got(field, "a boolean", value))


def _parse_float_env(value: object, *, default: float, field: str) -> float:
    if value is None:
        return default
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return float(value)
    if isinstance(value, str):
        raw = value.strip()
        if not raw:
            raise ValueError(format_expected_got(field, "a number", value))
        try:
            return float(raw)
        except ValueError as exc:
            raise ValueError(format_expected_got(field, "a number", raw)) from exc
    raise ValueError(format_expected_got(field, "a number", value))


class Config(BaseSettings):
    """Application config read once at CLI startup and passed through helpers.

    Every field is read from the environment variable named by its
    ``validation_alias`` and falls back to the declared default when that
    variable is unset. The settings are declared with ``extra="ignore"`` and
    ``frozen=True``.

    Prefer :meth:`from_env` over calling the constructor directly: it converts
    pydantic's ``ValidationError`` into :class:`ConfigError` with one
    ``ENV_KEY: message`` line per invalid setting.
    """

    model_config = SettingsConfigDict(extra="ignore", populate_by_name=True, frozen=True)

    test_map_path: str | None = Field(default=None, validation_alias="MSMODELING_TEST_MAP_PATH")
    base_branch: str = Field(default="master", validation_alias="MSMODELING_TEST_BASE_BRANCH")
    line_threshold: float = Field(default=60.0, validation_alias="MSMODELING_TEST_LINE_THRESHOLD")
    branch_threshold: float = Field(default=40.0, validation_alias="MSMODELING_TEST_BRANCH_THRESHOLD")
    benchmark_parallel: bool = Field(default=False, validation_alias="MSMODELING_BENCHMARK_PARALLEL")
    feishu_webhook_url: str = Field(default="", validation_alias="FEISHU_WEBHOOK_URL")
    msmodeling_cache: str = Field(default=".msmodeling_cache", validation_alias="MSMODELING_CACHE")
    weights_prune: bool = Field(default=False, validation_alias="MSMODELING_TEST_WEIGHTS_PRUNE")
    gitcode_owner: str = Field(default="", validation_alias="GITCODE_OWNER")
    gitcode_repo: str = Field(default="", validation_alias="GITCODE_REPO")
    gitcode_pr_number: int | None = Field(default=None, validation_alias="GITCODE_PR_NUMBER")
    gitcode_pat: str = Field(default="", validation_alias="GITCODE_PAT")

    @field_validator(
        "base_branch",
        "msmodeling_cache",
        "feishu_webhook_url",
        "gitcode_owner",
        "gitcode_repo",
        "gitcode_pat",
        mode="before",
    )
    @classmethod
    def _strip_strings(cls, value: object) -> object:
        """Trim surrounding whitespace from string inputs; pass others through."""
        if isinstance(value, str):
            return value.strip()
        return value

    @field_validator("gitcode_pr_number", mode="before")
    @classmethod
    def _parse_gitcode_pr_number(cls, value: object) -> int | None:
        """Parse the PR number, treating ``None`` and blank strings as unset.

        Raises:
            ValueError: If the value is a ``bool``, a string ``int()`` cannot
                parse, or any other non-integer type. The message names the
                environment variable, matching the other parsers in this
                module.
        """
        env_key = _FIELD_ENV_KEYS["gitcode_pr_number"]
        if value is None:
            return None
        # bool is a subclass of int; reject it like _parse_float_env does
        # instead of silently treating True/False as PR numbers.
        if isinstance(value, bool):
            raise ValueError(format_expected_got(env_key, "an integer", value))
        if isinstance(value, int):
            return value
        if isinstance(value, str):
            raw = value.strip()
            if not raw:
                return None
            try:
                return int(raw)
            except ValueError as exc:
                # Replace int()'s "invalid literal ..." text with the
                # module's standard message.
                raise ValueError(format_expected_got(env_key, "an integer", raw)) from exc
        raise ValueError(format_expected_got(env_key, "an integer", value))

    @field_validator("test_map_path", mode="before")
    @classmethod
    def _empty_test_map_path_is_none(cls, value: object) -> object:
        """Map an empty or whitespace-only string to ``None``; keep other values as given."""
        if isinstance(value, str) and value.strip() == "":
            return None
        return value

    @field_validator("line_threshold", mode="before")
    @classmethod
    def _parse_line_threshold(cls, value: object, info: ValidationInfo) -> float:
        """Parse the raw line threshold with ``_parse_float_env``."""
        field_name = info.field_name or "line_threshold"
        return _parse_float_env(value, default=60.0, field=_FIELD_ENV_KEYS[field_name])

    @field_validator("branch_threshold", mode="before")
    @classmethod
    def _parse_branch_threshold(cls, value: object, info: ValidationInfo) -> float:
        """Parse the raw branch threshold with ``_parse_float_env``."""
        field_name = info.field_name or "branch_threshold"
        return _parse_float_env(value, default=40.0, field=_FIELD_ENV_KEYS[field_name])

    @field_validator("benchmark_parallel", mode="before")
    @classmethod
    def _parse_benchmark_parallel(cls, value: object, info: ValidationInfo) -> bool:
        """Parse the raw ``benchmark_parallel`` flag with ``_parse_bool_env``."""
        field_name = info.field_name or "benchmark_parallel"
        return _parse_bool_env(value, default=False, field=_FIELD_ENV_KEYS[field_name])

    @field_validator("weights_prune", mode="before")
    @classmethod
    def _parse_weights_prune(cls, value: object, info: ValidationInfo) -> bool:
        """Parse the raw ``weights_prune`` flag with ``_parse_bool_env``."""
        field_name = info.field_name or "weights_prune"
        return _parse_bool_env(value, default=False, field=_FIELD_ENV_KEYS[field_name])

    @field_validator("line_threshold", "branch_threshold")
    @classmethod
    def _validate_threshold(cls, value: float) -> float:
        """Require a parsed threshold to lie in ``[0, _THRESHOLD_MAX]``.

        Raises:
            ValueError: If the value is outside the range. NaN also fails,
                because every comparison with NaN is false.
        """
        if not (0 <= value <= _THRESHOLD_MAX):
            raise ValueError(f"must be in [0, {_THRESHOLD_MAX:g}], got {value}")
        return value

    @classmethod
    def from_env(cls) -> Config:
        """Build the config from the process environment.

        Returns:
            A validated ``Config`` instance.

        Raises:
            ConfigError: If any setting fails validation; the original
                ``ValidationError`` is chained as the cause.
        """
        try:
            return cls()
        except ValidationError as exc:
            raise ConfigError(_format_validation_error(exc)) from exc