"""Maintain authoritative test_map against a target branch HEAD.
CLI entry for ``run_test_map_sync.sh``. Reads and writes only
``MSMODELING_TEST_MAP_PATH``; OBS upload/download stays outside the repo.
"""
from __future__ import annotations
import argparse
import logging
import os
import shlex
import signal
import subprocess
import sys
import time
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
from pathlib import Path
from scripts.helpers._config import Config, ConfigError, format_expected_got
from scripts.helpers._paths import REPO_ROOT
from scripts.helpers.ci_gate.diff import (
ephemeral_target_checkout,
fetch_changed_paths,
is_git_ancestor,
resolve_target_head,
)
from scripts.helpers.ci_gate.gate_policy import load_gate_policy
from scripts.helpers.common._logging import log_env_audit, setup_logger
from scripts.helpers.common.build_test_map import (
TestMap,
collect_allowed_node_ids,
collect_test_map,
prune_missing_source_keys,
write_test_map,
)
from scripts.helpers.common.coverage_config import cov_pytest_args, pytest_xdist_args
from scripts.helpers.common.test_map_config import (
TEST_MAP_COLLECTION_MARKER,
TEST_MAP_EXECUTION_MARKER,
resolve_test_map_path,
)
from scripts.helpers.common.test_map_loader import load_test_map_with_commit
# Poll interval (seconds) used when MSMODELING_TEST_MAP_SYNC_INTERVAL is unset or blank.
_DEFAULT_SYNC_INTERVAL_SECONDS: Final = 60.0
# One-element list used as a mutable flag: the signal handlers set index 0 to
# True, and the watch loop polls it to know when to stop.
_shutdown_flag: list[bool] = [False]
def _parse_sync_interval() -> float:
    """Return the watch poll interval, in seconds, from the environment.

    Reads ``MSMODELING_TEST_MAP_SYNC_INTERVAL``. An unset or blank value
    yields ``_DEFAULT_SYNC_INTERVAL_SECONDS``.

    Raises:
        ConfigError: If the value is not a number, is not positive, or is
            not finite (``nan`` / ``inf``).
    """
    raw = os.environ.get("MSMODELING_TEST_MAP_SYNC_INTERVAL", "").strip()
    if not raw:
        return _DEFAULT_SYNC_INTERVAL_SECONDS
    try:
        interval = float(raw)
    except ValueError as exc:
        raise ConfigError(format_expected_got("MSMODELING_TEST_MAP_SYNC_INTERVAL", "a number", raw)) from exc
    if interval <= 0:
        raise ConfigError(format_expected_got("MSMODELING_TEST_MAP_SYNC_INTERVAL", "positive", interval))
    # float() also accepts "nan" and "inf". NaN compares False against
    # everything (so it slipped past the check above) and would make a
    # deadline comparison always fail; infinity would never poll again.
    # ``interval < inf`` is False for both.
    if not interval < float("inf"):
        raise ConfigError(format_expected_got("MSMODELING_TEST_MAP_SYNC_INTERVAL", "finite", interval))
    return interval
def resolve_target_branch(*, cli_target: str | None, cfg: Config) -> str:
    """Pick the branch to sync against.

    Precedence: a non-empty *cli_target*, then a non-blank
    ``MSMODELING_TEST_MAP_TARGET_BRANCH`` environment value (whitespace
    stripped), then ``cfg.base_branch``.
    """
    from_env = os.environ.get("MSMODELING_TEST_MAP_TARGET_BRANCH", "").strip()
    return cli_target or from_env or cfg.base_branch
def _test_file_for_node(test_node: str) -> str:
    """Return the text of *test_node* before its first ``::`` (all of it if there is none)."""
    file_part, _sep, _rest = test_node.partition("::")
    return file_part
def apply_incremental_test_map_update(
    existing_map: TestMap,
    fresh_map: TestMap,
    touched_paths: frozenset[str],
) -> TestMap:
    """Merge *fresh_map* into *existing_map* for git-touched product or test paths.

    For every test node present in either map:

    * If the node's test file is in *touched_paths*, the node's entry is
      taken wholesale from *fresh_map*; when *fresh_map* has no (or an
      empty) entry for it, the node is dropped.
    * Otherwise, sources from *existing_map* whose path is NOT touched are
      kept, and sources from *fresh_map* whose path IS touched are added.

    Nodes that end up with no sources are omitted. Symbol lists are copied,
    so the inputs are not aliased by the result. The merged map is passed
    through ``prune_missing_source_keys`` before being returned.
    """
    merged: TestMap = {}
    # Iterate in sorted order: plain set iteration order varies between
    # interpreter runs (string hash randomization), which made the merged
    # dict's key order non-reproducible.
    for test_node in sorted(set(existing_map) | set(fresh_map)):
        if _test_file_for_node(test_node) in touched_paths:
            fresh_sources = fresh_map.get(test_node)
            if fresh_sources:
                merged[test_node] = {src: list(syms) for src, syms in fresh_sources.items()}
            # A touched test file never falls back to stale existing data.
            continue
        sources: dict[str, list[str]] = {
            src_path: list(symbols)
            for src_path, symbols in existing_map.get(test_node, {}).items()
            if src_path not in touched_paths
        }
        sources.update(
            (src_path, list(symbols))
            for src_path, symbols in fresh_map.get(test_node, {}).items()
            if src_path in touched_paths
        )
        if sources:
            merged[test_node] = sources
    return prune_missing_source_keys(merged)
def build_test_map_pytest_cmd(python_exe: str) -> list[str]:
    """Assemble the pytest argv used to gather coverage for the test map.

    Runs ``python_exe -m pytest`` over ``tests/smoke/`` and
    ``tests/regression/``, selected by ``TEST_MAP_EXECUTION_MARKER``, with
    the arguments supplied by ``pytest_xdist_args()`` and
    ``cov_pytest_args(cov_context=True)``, plus quiet-output options.
    """
    launcher = [python_exe, "-m", "pytest"]
    test_roots = ["tests/smoke/", "tests/regression/"]
    selection = ["-m", TEST_MAP_EXECUTION_MARKER]
    plugin_args = [*pytest_xdist_args(), *cov_pytest_args(cov_context=True)]
    reporting = ["-q", "--no-header", "--tb=short", "--disable-warnings"]
    return launcher + test_roots + selection + plugin_args + reporting
def run_test_map_pytest(repo_root: Path, python_exe: str) -> int:
    """Run the test-map pytest command in *repo_root* and return its exit code.

    The child process inherits the current environment with ``PYTHONPATH``
    replaced by *repo_root*. A non-zero exit is returned, not raised.
    """
    command = build_test_map_pytest_cmd(python_exe)
    child_env = {**os.environ, "PYTHONPATH": str(repo_root)}
    logging.getLogger(__name__).info("Running pytest: %s", shlex.join(command))
    completed = subprocess.run(command, cwd=repo_root, env=child_env, check=False)
    return completed.returncode
def can_incremental_sync(repo_root: Path, built_from_commit: str, target_head: str) -> bool:
    """Tell whether the map built at *built_from_commit* can be merged forward to *target_head*.

    True when both commits are the same; otherwise the result of
    ``is_git_ancestor(repo_root, built_from_commit, target_head)``.
    """
    return built_from_commit == target_head or is_git_ancestor(repo_root, built_from_commit, target_head)
def _try_load_existing_map(cfg: Config) -> tuple[TestMap | None, str | None]:
    """Load the current test map and its recorded commit, if usable.

    Returns:
        The pair returned by ``load_test_map_with_commit(cfg)``, or
        ``(None, None)`` when the resolved map path is not a file or the
        loader raises ``ConfigError``.
    """
    map_path = resolve_test_map_path(cfg, must_exist=False)
    if not map_path.is_file():
        return None, None
    try:
        return load_test_map_with_commit(cfg)
    except ConfigError as exc:
        # Still best-effort (the caller falls back to a full rebuild), but
        # record why: a silent fallback hides the cause of a costly rebuild.
        logging.getLogger(__name__).warning("Ignoring unreadable test_map %s: %s", map_path, exc)
        return None, None
def _collect_fresh_map() -> TestMap:
    """Collect a test map for ``TEST_MAP_COLLECTION_MARKER`` using the repo's gate policy roots."""
    marker = TEST_MAP_COLLECTION_MARKER
    policy = load_gate_policy(REPO_ROOT)
    node_ids = collect_allowed_node_ids(marker)
    return collect_test_map(marker_expr=marker, roots=policy.roots, allowed_node_ids=node_ids)
def _run_pytest_and_collect_fresh_map(repo_root: Path, python_exe: str) -> TestMap | None:
    """Run the test-map pytest pass, then collect a fresh map.

    Returns ``None`` without collecting when pytest exits non-zero.
    """
    if run_test_map_pytest(repo_root, python_exe) == 0:
        return _collect_fresh_map()
    return None
def _full_rebuild_test_map(
    cfg: Config,
    *,
    target_branch: str,
    target_head: str,
    logger: logging.Logger,
    reason: str,
) -> int:
    """Rebuild the whole test map from *target_branch* and write it out.

    Creates the map's parent directory if needed, logs *reason* as a
    warning, runs pytest and collection under ``ephemeral_target_checkout``,
    and writes the result stamped with ``built_from_commit=target_head``.

    Returns:
        0 when the map was written, 1 when pytest failed (nothing written).
    """
    destination = resolve_test_map_path(cfg, must_exist=False)
    destination.parent.mkdir(parents=True, exist_ok=True)
    logger.warning("test_map full rebuild: %s", reason)
    # NOTE(review): only the pytest run + collection is assumed to sit inside
    # the checkout context; the write happens after it exits — confirm.
    with ephemeral_target_checkout(REPO_ROOT, target_branch):
        rebuilt_map = _run_pytest_and_collect_fresh_map(REPO_ROOT, sys.executable)
    if rebuilt_map is not None:
        write_test_map(destination, rebuilt_map, built_from_commit=target_head)
        logger.info(
            "test_map full rebuild wrote %s at built_from_commit=%s",
            destination,
            target_head[:12],
        )
        return 0
    logger.error("test_map full rebuild aborted: pytest failed")
    return 1
def sync_test_map_once(
    cfg: Config,
    *,
    target_branch: str,
    logger: logging.Logger,
) -> int:
    """Run one sync cycle of the test map against *target_branch*'s HEAD.

    Outcomes, in order of evaluation:

    * No usable existing map or no recorded commit -> full rebuild.
    * Recorded commit equals target HEAD -> nothing to do, return 0.
    * ``can_incremental_sync`` says no -> full rebuild.
    * Otherwise -> incremental update limited to the paths changed between
      the recorded commit and target HEAD.

    Returns:
        0 on success (including "already up to date"), 1 when pytest failed.
    """
    map_path = resolve_test_map_path(cfg, must_exist=False)
    existing_map, built_from_commit = _try_load_existing_map(cfg)
    target_head = resolve_target_head(REPO_ROOT, target_branch)
    logger.info(
        "test_map sync: built_from_commit=%s target=%s (%s)",
        (built_from_commit or "(none)")[:12],
        target_head[:12],
        target_branch,
    )
    # Without both a map and the commit it was built from there is no
    # baseline to diff against.
    if existing_map is None or not built_from_commit:
        return _full_rebuild_test_map(
            cfg,
            target_branch=target_branch,
            target_head=target_head,
            logger=logger,
            reason="test_map file missing, unreadable, or built_from_commit absent",
        )
    if built_from_commit == target_head:
        logger.info("test_map is up to date with target HEAD")
        return 0
    if not can_incremental_sync(REPO_ROOT, built_from_commit, target_head):
        return _full_rebuild_test_map(
            cfg,
            target_branch=target_branch,
            target_head=target_head,
            logger=logger,
            reason=(f"built_from_commit {built_from_commit[:12]} is not an ancestor of target HEAD {target_head[:12]}"),
        )
    touched_paths = fetch_changed_paths(REPO_ROOT, built_from_commit, target_head)
    logger.info(
        "Incremental update range %s..%s (%d path(s) touched)",
        built_from_commit[:12],
        target_head[:12],
        len(touched_paths),
    )
    # NOTE(review): only the pytest run + collection is assumed to sit inside
    # the checkout context; merge and write happen after it exits — confirm.
    with ephemeral_target_checkout(REPO_ROOT, target_branch):
        fresh_map = _run_pytest_and_collect_fresh_map(REPO_ROOT, sys.executable)
    if fresh_map is None:
        # The existing map file is left untouched on failure.
        logger.error("test_map sync aborted: pytest failed")
        return 1
    updated_map = apply_incremental_test_map_update(existing_map, fresh_map, touched_paths)
    write_test_map(map_path, updated_map, built_from_commit=target_head)
    logger.info("test_map sync wrote %s at built_from_commit=%s", map_path, target_head[:12])
    return 0
def _register_shutdown_handlers() -> None:
    """Install one handler for SIGINT and SIGTERM that sets ``_shutdown_flag[0]``."""

    def _request_shutdown(_signum: int, _frame: object) -> None:
        _shutdown_flag[0] = True

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, _request_shutdown)
def sync_test_map_watch(
    cfg: Config,
    *,
    target_branch: str,
    interval_seconds: float,
    logger: logging.Logger,
) -> int:
    """Repeat sync cycles every *interval_seconds* until a shutdown signal arrives.

    Resets the shutdown flag and installs the signal handlers first. A cycle
    that returns non-zero is logged and retried after the interval; a
    ``ConfigError`` ends the loop immediately.

    Returns:
        0 after a signal-requested stop, 1 if a cycle raised ``ConfigError``.
    """
    _shutdown_flag[0] = False
    _register_shutdown_handlers()
    logger.info("test_map sync watch: target=%s interval=%.0fs", target_branch, interval_seconds)

    def _stop_requested() -> bool:
        return _shutdown_flag[0]

    while not _stop_requested():
        try:
            cycle_status = sync_test_map_once(cfg, target_branch=target_branch, logger=logger)
        except ConfigError as exc:
            logger.error("%s", exc)
            return 1
        if cycle_status != 0:
            logger.error("test_map sync cycle failed (exit %d); retrying after interval", cycle_status)
        # Sleep in short slices so a shutdown request is noticed promptly.
        wake_at = time.monotonic() + interval_seconds
        while not _stop_requested() and time.monotonic() < wake_at:
            time.sleep(0.1)

    logger.info("test_map sync watch stopped")
    return 0
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser.

    Exactly one of ``--once`` / ``--watch`` is required; ``--target-branch``
    and ``--interval`` are optional and default to ``None``.
    """
    parser = argparse.ArgumentParser(description="Sync MSMODELING_TEST_MAP_PATH with a target branch HEAD.")

    run_mode = parser.add_mutually_exclusive_group(required=True)
    for flag, help_text in (
        ("--once", "Run one sync cycle and exit."),
        ("--watch", "Poll target branch HEAD until interrupted."),
    ):
        run_mode.add_argument(flag, action="store_true", help=help_text)

    target_help = (
        "Target branch or remote/branch "
        "(default: MSMODELING_TEST_MAP_TARGET_BRANCH or MSMODELING_TEST_BASE_BRANCH)."
    )
    parser.add_argument("--target-branch", default=None, help=target_help)

    interval_help = "Poll interval in seconds for --watch (default: MSMODELING_TEST_MAP_SYNC_INTERVAL or 60)."
    parser.add_argument("--interval", type=float, default=None, help=interval_help)
    return parser
def _log_sync_env(logger: logging.Logger, target_branch: str, interval_seconds: float | None) -> None:
    """Log the effective target branch and, when given, the poll interval at INFO level."""
    entries: list[tuple[str, object]] = [(" MSMODELING_TEST_MAP_TARGET_BRANCH = %s", target_branch)]
    if interval_seconds is not None:
        entries.append((" MSMODELING_TEST_MAP_SYNC_INTERVAL = %s", interval_seconds))
    for template, value in entries:
        logger.info(template, value)
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse arguments, load config, and run one sync or the watch loop.

    Args:
        argv: Argument list for the parser; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Process exit code: the result of the selected sync mode, or 1 when
        ``MSMODELING_TEST_MAP_PATH`` is not configured or a ``ConfigError``
        is raised during setup or sync.
    """
    logger = setup_logger("test_map_sync")
    args = build_arg_parser().parse_args(argv)
    try:
        # Config loading sits inside the handler so that a ConfigError raised
        # here is logged and mapped to exit code 1 like the ones below.
        cfg = Config.from_env()
        log_env_audit(cfg, logger)
        if not cfg.test_map_path:
            logger.error("MSMODELING_TEST_MAP_PATH is required for test_map sync")
            return 1
        target_branch = resolve_target_branch(cli_target=args.target_branch, cfg=cfg)
        if args.interval is not None:
            interval, interval_source = args.interval, "--interval"
        else:
            interval, interval_source = _parse_sync_interval(), "MSMODELING_TEST_MAP_SYNC_INTERVAL"
        # --interval was never validated: zero, negative or NaN makes the
        # watch loop spin without sleeping, and inf never polls again.
        # The chained comparison is False for all of those. Only enforced
        # in --watch mode, where the interval is actually used.
        if args.watch and not 0 < interval < float("inf"):
            raise ConfigError(format_expected_got(interval_source, "a positive finite number", interval))
        _log_sync_env(logger, target_branch, interval if args.watch else None)
        if args.once:
            return sync_test_map_once(cfg, target_branch=target_branch, logger=logger)
        return sync_test_map_watch(
            cfg,
            target_branch=target_branch,
            interval_seconds=interval,
            logger=logger,
        )
    except ConfigError as exc:
        logger.error("%s", exc)
        return 1
# Script entry: exit the process with main()'s return value as the status code.
if __name__ == "__main__":
    raise SystemExit(main())