"""Load and validate test_map JSON and CI gate baseline."""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from scripts.helpers._config import Config, ConfigError, format_expected_got
from scripts.helpers.ci_gate.diff import is_git_ancestor
from scripts.helpers.ci_gate.models import Baseline
from scripts.helpers.ci_gate.policy import load_gate_policy
from scripts.helpers.common.coverage_config import product_roots
from scripts.helpers.common.test_map_config import resolve_test_map_path
if TYPE_CHECKING:
from pathlib import Path
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Node-oriented test map: pytest node id -> {product source file -> [symbols]}.
TestMap = dict[str, dict[str, list[str]]]
@dataclass(frozen=True, slots=True)
class TestMapFreshness:
block_message: str | None = None
warn_message: str | None = None
def _validate_test_node_key(test_node: str) -> None:
    """Reject a test_map key that is not a safe pytest node id under ``tests/``.

    The ``..`` traversal check applies only to the file-path part of the node id
    (everything before the first ``::``), split into segments on ``/`` and ``\\``.
    Parametrize ids after ``::`` are not filesystem paths and may legitimately
    contain ``..`` (e.g. ``test_range[0..5]`` or ``test_load[../fixtures]``).
    Scanning the whole string would reject such valid node ids.

    Raises:
        ConfigError: if the key is absolute, its path part contains a ``..``
            segment, it does not start with ``tests/``, or it lacks ``::``.
    """
    # Normalise Windows-style separators so "tests\..\x.py" is still caught.
    path_part = test_node.partition("::")[0].replace("\\", "/")
    if test_node.startswith("/") or ".." in path_part.split("/"):
        raise ConfigError(f"test_map: invalid test node key: {test_node!r}")
    if not test_node.startswith("tests/") or "::" not in test_node:
        raise ConfigError(f"test_map: map key must be a pytest node id under tests/: {test_node!r}")
def _validate_canonical_symbol(symbol: str, *, test_node: str, source_file: str) -> None:
    """Ensure *symbol* is not written in dotted ``Class.method`` form.

    Names without a dot pass. Names containing ``::`` or ``@`` are exempt from
    the dot check.

    Raises:
        ConfigError: if *symbol* contains ``.`` but neither ``::`` nor ``@``.
    """
    dotted = "." in symbol
    exempt = "::" in symbol or "@" in symbol
    if not dotted or exempt:
        return
    raise ConfigError(
        f"test_map: symbol under {test_node!r} -> {source_file!r}: "
        f"symbol must use canonical Class::method form, got {symbol!r}"
    )
def _validate_source_entry_types(test_node: str, source_file: object, symbols: object) -> None:
    """Type-check one ``source_file -> symbols`` entry belonging to *test_node*.

    The source file is checked before the symbols, so a bad key is reported
    even when its value is also malformed.

    Raises:
        ConfigError: if *source_file* is not a ``str`` or *symbols* is not a ``list``.
    """
    if isinstance(source_file, str):
        if isinstance(symbols, list):
            return
        where = f"symbols under {test_node!r} -> {source_file!r}"
        raise ConfigError(f"test_map: {format_expected_got(where, 'a list', symbols)}")
    where = f"source file under {test_node!r}"
    raise ConfigError(f"test_map: {format_expected_got(where, 'a string', source_file)}")
def _validate_map_payload(
    inner: dict[str, object],
    roots: tuple[str, ...],
) -> TestMap:
    """Validate each ``test_node -> {source_file: [symbols]}`` entry and copy it out.

    Test nodes whose source mapping is empty are left out of the result. Each
    symbol list is copied, so the returned map does not share lists with *inner*.

    Args:
        inner: the raw ``map`` object decoded from JSON.
        roots: prefixes that every source file path must start with.

    Raises:
        ConfigError: on a non-string or malformed test node key, a non-object
            value, a mistyped source entry, a source file outside *roots*, a
            non-string symbol, or a non-canonical symbol.
    """
    allowed_prefixes = tuple(roots)
    result: TestMap = {}
    for node_id, raw_sources in inner.items():
        if not isinstance(node_id, str):
            raise ConfigError(f"test_map: {format_expected_got('map key', 'a string', node_id)}")
        _validate_test_node_key(node_id)
        if not isinstance(raw_sources, dict):
            what = f"value for {node_id!r}"
            raise ConfigError(f"test_map: {format_expected_got(what, 'an object', raw_sources)}")

        per_file: dict[str, list[str]] = {}
        for path, names in raw_sources.items():
            _validate_source_entry_types(node_id, path, names)
            # Narrowing for type checkers only: the call above already guarantees these.
            assert isinstance(path, str)
            assert isinstance(names, list)
            if not path.startswith(allowed_prefixes):
                raise ConfigError(
                    f"test_map: source file must start with a product root ({', '.join(roots)}): {path!r}"
                )
            if any(not isinstance(name, str) for name in names):
                where = f"symbols under {node_id!r} -> {path!r}"
                raise ConfigError(f"test_map: {format_expected_got(where, 'strings', names)}")
            for name in names:
                _validate_canonical_symbol(name, test_node=node_id, source_file=path)
            per_file[path] = list(names)

        if per_file:
            result[node_id] = per_file
    return result
def parse_test_map_map_object(
    inner: object,
    *,
    roots: tuple[str, ...],
) -> TestMap:
    """Validate and return a node-oriented ``map`` object from test_map JSON.

    A non-empty map in which no key looks like a pytest node id
    (``tests/...::...``) is rejected up front, quoting the first key. Otherwise
    every entry is validated in detail.

    Raises:
        ConfigError: if *inner* is not a dict, no key resembles a node id, or
            any entry fails per-entry validation.
    """
    if not isinstance(inner, dict):
        raise ConfigError(f"test_map: {format_expected_got('map', 'an object', inner)}")
    if inner and all(
        not (isinstance(key, str) and key.startswith("tests/") and "::" in key) for key in inner
    ):
        first_key = next(iter(inner))
        raise ConfigError(
            f"test_map: map must be keyed by pytest node ids (tests/...py::test_name); got {first_key!r}"
        )
    return _validate_map_payload(inner, roots)
def _parse_test_map_payload(
    data: dict[str, object],
    resolved_roots: tuple[str, ...],
) -> tuple[TestMap, str | None]:
    """Validate a decoded test_map document and return ``(map, built_from_commit)``.

    Args:
        data: the decoded top-level JSON object.
        resolved_roots: prefixes that every source file in the map must start with.

    Returns:
        The validated node-oriented map and the ``built_from_commit`` string,
        or ``None`` when that key is absent or ``null``.

    Raises:
        ConfigError: if ``schema_version`` is present and not ``1`` (booleans
            are rejected), if ``built_from_commit`` is not a string, or if ``map``
            fails validation. A missing ``map`` is reported as a non-object.
    """
    schema_version = data.get("schema_version")
    # bool is a subclass of int and ``True == 1``, so a JSON ``true`` would pass
    # the membership test on its own; reject booleans explicitly.
    if isinstance(schema_version, bool) or schema_version not in (1, None):
        raise ConfigError(f"test_map: unsupported schema_version {schema_version!r}")
    built_from_commit = data.get("built_from_commit")
    if built_from_commit is not None and not isinstance(built_from_commit, str):
        raise ConfigError(f"test_map: {format_expected_got('built_from_commit', 'a string', built_from_commit)}")
    inner = data.get("map")
    return parse_test_map_map_object(inner, roots=resolved_roots), built_from_commit
def load_test_map(
    cfg: Config,
    *,
    roots: tuple[str, ...] | None = None,
) -> TestMap:
    """Load and validate the test_map, dropping its ``built_from_commit``.

    Thin wrapper over ``load_test_map_with_commit``, which defines how *roots*
    defaults and which errors are raised.
    """
    return load_test_map_with_commit(cfg, roots=roots)[0]
def load_test_map_with_commit(
    cfg: Config,
    *,
    roots: tuple[str, ...] | None = None,
) -> tuple[TestMap, str | None]:
    """Read, decode and validate test_map JSON, returning ``(map, built_from_commit)``.

    Args:
        cfg: project configuration used to locate the test_map file.
        roots: prefixes that every source file must start with; ``product_roots()``
            is used when ``None``.

    Raises:
        ConfigError: if the file is not valid UTF-8, is not valid JSON, its root
            is not an object, or its contents fail validation.
    """
    resolved_roots = roots if roots is not None else product_roots()
    map_path = resolve_test_map_path(cfg, must_exist=True)
    try:
        raw_text = map_path.read_text(encoding="utf-8")
    except UnicodeDecodeError as exc:
        # Report undecodable bytes as a config error, like malformed JSON below,
        # instead of leaking a bare UnicodeDecodeError to callers.
        raise ConfigError(f"test_map: file at {map_path} is not valid UTF-8: {exc}") from exc
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ConfigError(f"test_map: invalid JSON at {map_path}: {exc}") from exc
    if not isinstance(data, dict):
        raise ConfigError(f"test_map: {format_expected_got('root', 'an object', data)}")
    return _parse_test_map_payload(data, resolved_roots)
def assess_test_map_freshness(
    repo_root: Path,
    built_from_commit: str | None,
    merge_base: str,
) -> TestMapFreshness:
    """Return block/warn messages for stale test_map relative to merge-base.

    Outcomes, checked in order:
      * ``built_from_commit`` is ``None`` or empty: block.
      * ``is_git_ancestor(repo_root, merge_base, built_from_commit)`` is true:
        fresh (no messages).
      * ``is_git_ancestor(repo_root, built_from_commit, merge_base)`` is true:
        warn that the map is behind.
      * otherwise: block.

    Commit ids are shortened to their first 12 characters in messages.
    """
    # NOTE(review): is_git_ancestor(repo, a, b) is read here as "a is an ancestor
    # of b", which matches the wording of the messages below; confirm in ci_gate.diff.
    if not built_from_commit:
        return TestMapFreshness(
            block_message="test_map: built_from_commit is required; rebuild test_map via nightly or build_test_map"
        )
    if is_git_ancestor(repo_root, merge_base, built_from_commit):
        return TestMapFreshness()
    map_is_behind = is_git_ancestor(repo_root, built_from_commit, merge_base)
    built_short = built_from_commit[:12]
    base_short = merge_base[:12]
    if map_is_behind:
        return TestMapFreshness(
            warn_message=(
                f"test_map: built_from_commit {built_short} is behind merge-base "
                f"{base_short}; continuing with stale map"
            )
        )
    return TestMapFreshness(
        block_message=(
            f"test_map: stale built_from_commit {built_short} is not an ancestor "
            f"of merge-base {base_short}"
        )
    )
def validate_test_map_freshness(
    repo_root: Path,
    built_from_commit: str | None,
    merge_base: str,
) -> None:
    """Apply the test_map freshness policy to the gate.

    Raises ``ConfigError`` carrying the block message when the policy blocks.
    Otherwise, logs the warn message at WARNING level if one was produced.
    """
    verdict = assess_test_map_freshness(repo_root, built_from_commit, merge_base)
    blocking = verdict.block_message
    if blocking:
        raise ConfigError(blocking)
    advisory = verdict.warn_message
    if advisory:
        logger.warning("%s", advisory)
def load_baseline(repo_root: Path, cfg: Config) -> tuple[Baseline, str | None]:
    """Load full gate baseline: test_map + gate policy.

    The gate policy is read first because its ``roots`` restrict which source
    files the test_map may reference. Returns the baseline together with the
    test_map's ``built_from_commit`` (``None`` when the file has none).
    """
    gate_policy = load_gate_policy(repo_root)
    mapping, commit = load_test_map_with_commit(cfg, roots=gate_policy.roots)
    return Baseline(test_map=mapping, policy=gate_policy), commit
def is_product_source(path: str, prefixes: tuple[str, ...]) -> bool:
    """Return True when *path* starts with any of *prefixes*; False if there are none."""
    return path.startswith(tuple(prefixes))