"""Shared pytest subprocess helpers for CI gate and test_map tooling."""
from __future__ import annotations
import logging
import os
import re
import subprocess
import sys
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
from collections.abc import Sequence
from scripts.helpers._config import ConfigError
from scripts.helpers._paths import REPO_ROOT
logger = logging.getLogger(__name__)
_NOT_FOUND_RE = re.compile(r"^ERROR: not found: ([^\n\r]+)", re.MULTILINE)
PYTEST_IGNORE_ADDOPTS: Final[list[str]] = ["-o", "addopts="]
_DEFAULT_RUN_ARGS: Final[tuple[str, ...]] = (
"-vv",
"--tb=short",
"--durations=20",
"--disable-warnings",
)
def _parse_collect_only_node_ids(stdout: str) -> tuple[str, ...]:
node_ids: list[str] = []
for line in stdout.splitlines():
stripped = line.strip()
if "::" in stripped and stripped.startswith("tests/"):
node_ids.append(stripped)
return tuple(node_ids)
def _normalize_node_id(node_id: str) -> str:
"""Map pytest node ids to repo-relative ``tests/...`` form for comparison."""
normalized = node_id.strip()
tests_idx = normalized.find("tests/")
if tests_idx >= 0:
return normalized[tests_idx:]
return normalized
def _parse_not_found_node_ids(stderr: str) -> frozenset[str]:
return frozenset(_normalize_node_id(node_id) for node_id in _NOT_FOUND_RE.findall(stderr))
def collect_test_node_ids(targets: Sequence[str], *, marker: str | None) -> tuple[str, ...]:
"""Collect pytest node ids matching *marker* from collect-only stdout."""
return _collect_with_marker(targets, marker=marker)
def collect_all_test_node_ids(targets: Sequence[str]) -> tuple[str, ...]:
"""Collect pytest node ids from *targets* without applying a marker filter."""
return _collect_with_marker(targets, marker=None)
def _collect_with_marker(targets: Sequence[str], *, marker: str | None) -> tuple[str, ...]:
if not targets:
return ()
cmd = [
sys.executable,
"-m",
"pytest",
*targets,
*PYTEST_IGNORE_ADDOPTS,
"--collect-only",
"-q",
"--no-header",
]
if marker is not None:
cmd.extend(["-m", marker])
proc = subprocess.run(
cmd,
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
if proc.returncode not in (0, 5):
detail = (proc.stderr or proc.stdout or "").strip()
raise ConfigError(f"pytest collect-only failed (exit {proc.returncode})" + (f": {detail}" if detail else ""))
return _parse_collect_only_node_ids(proc.stdout)
def _run_collect_only(targets: Sequence[str], *, marker: str | None) -> subprocess.CompletedProcess[str]:
cmd = [
sys.executable,
"-m",
"pytest",
*targets,
*PYTEST_IGNORE_ADDOPTS,
"--collect-only",
"-q",
"--no-header",
]
if marker is not None:
cmd.extend(["-m", marker])
return subprocess.run(
cmd,
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
)
def filter_collectable_node_ids(targets: Sequence[str], *, marker: str | None) -> tuple[str, ...]:
"""Return collectable node ids; drop stale ids instead of failing the batch."""
if not targets:
return ()
if not all("::" in target for target in targets):
if marker is None:
return collect_all_test_node_ids(targets)
return collect_test_node_ids(targets, marker=marker)
proc = _run_collect_only(targets, marker=marker)
if proc.returncode not in (0, 4, 5):
detail = (proc.stderr or proc.stdout or "").strip()
raise ConfigError(f"pytest collect-only failed (exit {proc.returncode})" + (f": {detail}" if detail else ""))
batch_ids = _parse_collect_only_node_ids(proc.stdout)
if batch_ids:
target_set = frozenset(targets)
return tuple(node_id for node_id in batch_ids if node_id in target_set)
not_found = _parse_not_found_node_ids(proc.stderr or "")
if not_found:
remaining = tuple(target for target in targets if _normalize_node_id(target) not in not_found)
dropped = len(targets) - len(remaining)
if dropped:
sample = ", ".join(sorted(not_found)[:3])
logger.info("Dropped %d stale pytest node id(s); sample: %s", dropped, sample)
return remaining
if proc.returncode == 0:
return tuple(targets)
if proc.returncode == 5:
return ()
detail = (proc.stderr or proc.stdout or "").strip()
raise ConfigError("pytest collect-only exit 4 with unparseable stderr" + (f": {detail[:500]}" if detail else ""))
def count_collected_tests(targets: Sequence[str], *, marker: str) -> int:
"""Collect pytest node ids matching *marker* and return the count."""
return len(collect_test_node_ids(targets, marker=marker))
def xdist_worker_args(collected_count: int) -> list[str]:
"""Return pytest-xdist flags sized to the collected test count."""
if collected_count == 0:
return []
worker_count = min(os.cpu_count() or 1, max(collected_count, 1))
return ["-n", str(worker_count), "--dist", "worksteal"]
def build_pytest_cmd(
python: str,
targets: Sequence[str],
*,
marker: str | None,
collected_count: int,
extra_args: Sequence[str] = (),
) -> list[str]:
"""Assemble a pytest command with optional marker and collect-first xdist sizing."""
cmd = [
python,
"-m",
"pytest",
*targets,
*PYTEST_IGNORE_ADDOPTS,
]
if marker is not None:
cmd.extend(["-m", marker])
cmd.extend([*xdist_worker_args(collected_count), *_DEFAULT_RUN_ARGS, *extra_args])
return cmd