#!/usr/bin/env python3
"""Build test_map JSON from Coverage.py dynamic contexts (pytest --cov-context=test)."""

from __future__ import annotations

import json
import logging
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Final

from scripts.helpers._paths import REPO_ROOT
from scripts.helpers.common.ast_utils import iter_qualified_definition_spans, symbol_for_line
from scripts.helpers.common.coverage_config import product_roots
from scripts.helpers.common.coverage_omit import is_coverage_omitted_source
from scripts.helpers.common.pytest_runner import PYTEST_IGNORE_ADDOPTS

logger = logging.getLogger(__name__)

UNCLASSIFIED_SYMBOL: Final = "*"


def _relative_repo_key(abs_path: str, roots: tuple[str, ...]) -> str | None:
    try:
        rel = Path(abs_path).resolve().relative_to(REPO_ROOT)
    except ValueError:
        return None
    key = rel.as_posix()
    if key.startswith(roots):
        return key
    return None


def _normalize_pytest_context(ctx: str) -> str:
    """Strip pytest-cov phase suffix ``|run``, ``|setup``, ``|teardown``."""
    return ctx.split("|", 1)[0].strip() if ctx else ""


def _collect_allowed_node_ids(
    marker_expr: str,
    pytest_args: list[str] | None = None,
) -> frozenset[str]:
    """Return node ids from smoke/regression directories matching marker_expr.

    Uses ``pytest --collect-only -q --no-header`` for stable machine-readable
    output. Strips parameterised suffixes ``[param]`` to match coverage context
    base node ids.
    """
    if pytest_args is None:
        pytest_args = [
            sys.executable,
            "-m",
            "pytest",
            *PYTEST_IGNORE_ADDOPTS,
            str(REPO_ROOT / "tests" / "smoke"),
            str(REPO_ROOT / "tests" / "regression"),
            "-m",
            marker_expr,
            "--collect-only",
            "-q",
            "--no-header",
        ]
    proc = subprocess.run(
        pytest_args,
        cwd=REPO_ROOT,
        capture_output=True,
        text=True,
        check=False,
    )
    if proc.returncode != 0:
        logger.error("pytest collect-only failed:\n%s", proc.stderr or proc.stdout)
        raise SystemExit(proc.returncode)

    node_ids: set[str] = set()
    for line in proc.stdout.splitlines():
        stripped = line.strip()
        if "::" in stripped and stripped.startswith("tests/"):
            base_id = stripped.split("[", 1)[0]
            node_ids.add(base_id)
    return frozenset(node_ids)


def collect_from_coverage(
    allowed_node_ids: frozenset[str],
    *,
    coverage_path: Path | None = None,
    roots: tuple[str, ...] | None = None,
) -> dict[str, dict[str, list[str]]]:
    """Read .coverage data, resolve symbols, return {rel_path: {symbol: [test_ids]}}.

    Returns empty dict if coverage data is missing or unreadable — logs warning
    to stderr so caller can distinguish "no data" from "empty map".
    """
    from coverage.data import CoverageData
    from coverage.misc import CoverageException

    if coverage_path is None:
        coverage_path = REPO_ROOT / ".coverage"
    resolved_roots = roots if roots is not None else product_roots()

    if not coverage_path.is_file():
        logger.warning("Coverage data not found: %s", coverage_path)
        return {}

    data = CoverageData(str(coverage_path))
    try:
        data.read()
    except CoverageException as exc:
        logger.warning("Failed to read coverage data: %s", exc)
        return {}
    except OSError as exc:
        logger.warning("Coverage data file error: %s", exc)
        return {}

    by_file: dict[str, dict[str, dict[str, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

    for measured in data.measured_files():
        key = _relative_repo_key(measured, resolved_roots)
        if key is None:
            continue
        if is_coverage_omitted_source(key, resolved_roots):
            continue
        ctxmap = data.contexts_by_lineno(measured)
        if not ctxmap:
            continue

        spans = iter_qualified_definition_spans(Path(measured).resolve())

        for line_no, ctxs in ctxmap.items():
            sym = symbol_for_line(spans, line_no)
            bucket = sym if sym is not None else UNCLASSIFIED_SYMBOL
            for ctx in ctxs:
                nid = _normalize_pytest_context(ctx)
                if nid and nid in allowed_node_ids:
                    by_file[key][bucket][nid] += 1

    result: dict[str, dict[str, list[str]]] = {}
    for fp, syms in sorted(by_file.items()):
        filtered: dict[str, list[str]] = {}
        for sym, test_lines in sorted(syms.items()):
            kept = sorted(nid for nid, count in test_lines.items() if count)
            if kept:
                filtered[sym] = kept
        if filtered:
            result[fp] = filtered

    return result


def _prune_missing_source_keys(
    mapping: dict[str, dict[str, list[str]]],
) -> dict[str, dict[str, list[str]]]:
    """Drop product paths that no longer exist on disk."""
    return {source_path: symbols for source_path, symbols in mapping.items() if (REPO_ROOT / source_path).is_file()}


def write_test_map(output_path: Path, mapping: dict[str, dict[str, list[str]]]) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {"schema_version": 1, "map": mapping}
    output_path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False) + "\n",
        encoding="utf-8",
    )
    symbol_count = sum(len(syms) for syms in mapping.values())
    logger.info(
        "test_map written: %s (%d source files, %d symbols)",
        output_path,
        len(mapping),
        symbol_count,
    )


def build_test_map(
    output_path: Path,
    *,
    marker_expr: str,
    coverage_path: Path | None = None,
    roots: tuple[str, ...] | None = None,
    allowed_node_ids: frozenset[str] | None = None,
) -> None:
    allowed = allowed_node_ids if allowed_node_ids is not None else _collect_allowed_node_ids(marker_expr)
    mapping = collect_from_coverage(allowed, coverage_path=coverage_path, roots=roots)
    mapping = _prune_missing_source_keys(mapping)
    write_test_map(output_path, mapping)


def collect_test_map(
    *,
    marker_expr: str,
    coverage_path: Path | None = None,
    roots: tuple[str, ...] | None = None,
    allowed_node_ids: frozenset[str] | None = None,
) -> dict[str, dict[str, list[str]]]:
    """Return test_map dict in memory — no file I/O.

    Same logic as build_test_map but returns the mapping directly.
    """
    allowed = allowed_node_ids if allowed_node_ids is not None else _collect_allowed_node_ids(marker_expr)
    mapping = collect_from_coverage(allowed, coverage_path=coverage_path, roots=roots)
    return _prune_missing_source_keys(mapping)


def detect_redundant_cases(
    mapping: dict[str, dict[str, list[str]]],
    *,
    jaccard_threshold: float = 0.85,
    max_per_symbol: int = 5,
) -> list[dict[str, object]]:
    """Return redundancy warnings from a test_map.

    Two checks:
    1. Symbols covered by more than *max_per_symbol* test cases.
    2. Pairs of test cases whose covered-symbol sets have Jaccard similarity
       >= *jaccard_threshold*.

    Returns a list of warning dicts suitable for nightly report inclusion.
    """
    warnings: list[dict[str, object]] = []

    test_to_symbols: dict[str, set[str]] = defaultdict(set)
    symbol_to_tests: dict[str, set[str]] = defaultdict(set)
    for src_file, symbols in mapping.items():
        for sym, test_ids in symbols.items():
            if sym == UNCLASSIFIED_SYMBOL:
                continue
            qualified = f"{src_file}::{sym}"
            for tid in test_ids:
                test_to_symbols[tid].add(qualified)
                symbol_to_tests[qualified].add(tid)

    for src_file, symbols in mapping.items():
        for sym, test_ids in symbols.items():
            if sym == UNCLASSIFIED_SYMBOL:
                continue
            if len(test_ids) > max_per_symbol:
                warnings.append(
                    {
                        "type": "over_covered_symbol",
                        "symbol": f"{src_file}::{sym}",
                        "test_count": len(test_ids),
                        "threshold": max_per_symbol,
                        "tests": sorted(test_ids),
                    }
                )

    compared_pairs: set[tuple[str, str]] = set()
    for tests in symbol_to_tests.values():
        test_list = sorted(tests)
        for i, a_id in enumerate(test_list):
            for b_id in test_list[i + 1 :]:
                pair = (a_id, b_id)
                if pair in compared_pairs:
                    continue
                compared_pairs.add(pair)
                a_syms = test_to_symbols[a_id]
                b_syms = test_to_symbols[b_id]
                intersection = a_syms & b_syms
                union = a_syms | b_syms
                if not union:
                    continue
                jaccard = len(intersection) / len(union)
                if jaccard >= jaccard_threshold:
                    warnings.append(
                        {
                            "type": "redundant_pair",
                            "test_a": a_id,
                            "test_b": b_id,
                            "jaccard": round(jaccard, 3),
                            "shared_symbols": sorted(intersection),
                        }
                    )

    return warnings