"""Shared pytest subprocess helpers for CI gate and test_map tooling."""

from __future__ import annotations

import os
import subprocess
import sys
from typing import TYPE_CHECKING, Final

if TYPE_CHECKING:
    from collections.abc import Sequence

from scripts.helpers._config import ConfigError
from scripts.helpers._paths import REPO_ROOT

# Command-line fragment spliced into every pytest invocation built in this
# module. ``-o addopts=`` overrides the ``addopts`` ini option with an empty
# value, so the flags used here are the only ones in effect.
PYTEST_IGNORE_ADDOPTS: Final[list[str]] = ["-o", "addopts="]

# Reporting flags that build_pytest_cmd appends to every test-run command
# (placed after the xdist sizing flags and before any caller-supplied extras).
_DEFAULT_RUN_ARGS: Final[tuple[str, ...]] = (
    "-vv",
    "--tb=short",
    "--durations=20",
    "--disable-warnings",
)


def _parse_collect_only_node_ids(stdout: str) -> tuple[str, ...]:
    """Extract node ids from collect-only output.

    A line counts as a node id when, after trimming surrounding whitespace,
    it starts with ``tests/`` and contains ``::``. Everything else (summary
    lines, blank lines, bare file paths) is skipped. Order is preserved.
    """
    trimmed_lines = (raw.strip() for raw in stdout.splitlines())
    return tuple(entry for entry in trimmed_lines if entry.startswith("tests/") and "::" in entry)


def collect_test_node_ids(targets: Sequence[str], *, marker: str) -> tuple[str, ...]:
    """Collect pytest node ids matching *marker* from collect-only stdout.

    The collect-only subprocess is launched through ``_run_collect_only`` so
    that the strict and lenient collectors always issue the same command.

    Args:
        targets: Paths or node ids handed to pytest. An empty sequence
            short-circuits to ``()`` without spawning a subprocess.
        marker: Marker expression passed to pytest via ``-m``.

    Returns:
        The node ids parsed from pytest's stdout, in output order.

    Raises:
        ConfigError: If pytest exits with any code other than 0 or 5. The
            message carries stderr (or stdout when stderr is empty).
    """
    if not targets:
        return ()

    proc = _run_collect_only(targets, marker=marker)
    # Exit 0 is a normal collection; exit 5 is pytest's "no tests collected",
    # which is a valid (empty) answer here rather than an error.
    if proc.returncode not in (0, 5):
        detail = (proc.stderr or proc.stdout or "").strip()
        raise ConfigError(f"pytest collect-only failed (exit {proc.returncode})" + (f": {detail}" if detail else ""))

    return _parse_collect_only_node_ids(proc.stdout)


def _run_collect_only(targets: Sequence[str], *, marker: str) -> subprocess.CompletedProcess[str]:
    """Run ``pytest --collect-only -q`` for *targets* filtered by *marker*.

    The subprocess runs with the current interpreter from ``REPO_ROOT``,
    captures stdout/stderr as text, and never raises on a non-zero exit;
    callers inspect ``returncode`` themselves.
    """
    argv: list[str] = [sys.executable, "-m", "pytest"]
    argv.extend(targets)
    argv.extend(PYTEST_IGNORE_ADDOPTS)
    argv.extend(("-m", marker))
    argv.extend(("--collect-only", "-q", "--no-header"))
    return subprocess.run(argv, cwd=REPO_ROOT, capture_output=True, text=True, check=False)


def _collect_test_node_ids_lenient(targets: Sequence[str], *, marker: str) -> tuple[str, ...]:
    """Collect node ids like collect_test_node_ids, but tolerate exit 4.

    Exit codes 0, 4 and 5 are all accepted and whatever node ids appear on
    stdout are returned. Any other exit code raises ConfigError. Empty
    *targets* yields ``()`` without running pytest.
    """
    if not targets:
        return ()

    result = _run_collect_only(targets, marker=marker)
    if result.returncode not in (0, 4, 5):
        detail = (result.stderr or result.stdout or "").strip()
        suffix = f": {detail}" if detail else ""
        raise ConfigError(f"pytest collect-only failed (exit {result.returncode})" + suffix)

    return _parse_collect_only_node_ids(result.stdout)


def _stderr_reports_all_targets_missing(proc: subprocess.CompletedProcess[str], targets: Sequence[str]) -> bool:
    """Return True when an exit-4 run mentions "not found:" and names every target.

    Both stderr and stdout are searched, since either may carry the message.
    The check is a plain substring match on each target.
    """
    if proc.returncode != 4:
        return False

    err_text = proc.stderr or ""
    out_text = proc.stdout or ""
    haystack = f"{err_text}\n{out_text}"
    if "not found:" not in haystack:
        return False

    for target in targets:
        if target not in haystack:
            return False
    return True


def filter_collectable_node_ids(targets: Sequence[str], *, marker: str) -> tuple[str, ...]:
    """Return the subset of *targets* pytest can still collect under *marker*.

    Stale node ids are dropped rather than failing the whole batch. When any
    target is not a ``::`` node id, the call is handed to
    collect_test_node_ids unchanged. Raises ConfigError when the batch
    collect-only run exits with a code other than 0, 4 or 5.
    """
    if not targets:
        return ()
    if any("::" not in target for target in targets):
        return collect_test_node_ids(targets, marker=marker)

    batch = _run_collect_only(targets, marker=marker)
    if batch.returncode not in (0, 4, 5):
        detail = (batch.stderr or batch.stdout or "").strip()
        suffix = f": {detail}" if detail else ""
        raise ConfigError(f"pytest collect-only failed (exit {batch.returncode})" + suffix)

    collected = _parse_collect_only_node_ids(batch.stdout)
    if collected:
        # Keep only ids that were asked for verbatim, in pytest's output order.
        requested = frozenset(targets)
        return tuple(node_id for node_id in collected if node_id in requested)

    if _stderr_reports_all_targets_missing(batch, targets):
        return ()

    # The batch listed nothing but not every target was reported missing:
    # probe each target on its own so the collectable ones are kept.
    survivors: list[str] = []
    for target in targets:
        if target in _collect_test_node_ids_lenient((target,), marker=marker):
            survivors.append(target)
    return tuple(survivors)


def count_collected_tests(targets: Sequence[str], *, marker: str) -> int:
    """Return how many node ids collect_test_node_ids finds for *targets* and *marker*."""
    node_ids = collect_test_node_ids(targets, marker=marker)
    return len(node_ids)


def xdist_worker_args(collected_count: int) -> list[str]:
    """Build the pytest-xdist flags for a run of *collected_count* tests.

    Returns an empty list when nothing was collected. Otherwise the worker
    count is the collected count (at least 1) capped at the CPU count, with
    1 assumed when the CPU count is unavailable.
    """
    if collected_count == 0:
        return []

    cpu_limit = os.cpu_count() or 1
    wanted = max(collected_count, 1)
    workers = min(cpu_limit, wanted)
    return ["-n", f"{workers}", "--dist", "worksteal"]


def build_pytest_cmd(
    python: str,
    targets: Sequence[str],
    *,
    marker: str,
    collected_count: int,
    extra_args: Sequence[str] = (),
) -> list[str]:
    """Build the argv for a pytest run with an explicit marker.

    Argument order is: interpreter and ``-m pytest``, the targets, the
    addopts override, ``-m <marker>``, xdist flags sized from
    *collected_count*, the module's default run flags, then *extra_args*.
    """
    argv: list[str] = [python, "-m", "pytest"]
    argv.extend(targets)
    argv.extend(PYTEST_IGNORE_ADDOPTS)
    argv.extend(("-m", marker))
    argv.extend(xdist_worker_args(collected_count))
    argv.extend(_DEFAULT_RUN_ARGS)
    argv.extend(extra_args)
    return argv