"""CI incremental gate: diff analysis, test selection, pytest.
CLI entry point for run_ci_gate.sh. Orchestrates diff → gate plan → execution.
"""
from __future__ import annotations
import logging
import shlex
import subprocess
import sys
from collections import Counter
from pathlib import Path
from scripts.helpers._config import Config, ConfigError
from scripts.helpers._paths import REPO_ROOT
from scripts.helpers.ci_gate.diff import classify_changes, fetch_diff, resolve_base_ref
from scripts.helpers.ci_gate.errors import format_blocking_errors, format_pytest_failure_hint
from scripts.helpers.ci_gate.gate_policy import (
SourceExemption,
TestExemption,
is_test_exempt,
validate_gate_policy_if_changed,
)
from scripts.helpers.ci_gate.models import (
Baseline,
ChangeSet,
CiGatePlan,
ExecutionPlan,
GateError,
TestRunWave,
)
from scripts.helpers.ci_gate.rules import (
_product_paths,
gate_deleted_source,
gate_deleted_tests,
gate_modified_source,
gate_new_source,
gate_new_tests,
)
from scripts.helpers.common._logging import log_env_audit, setup_logger
from scripts.helpers.common.coverage_config import cov_pytest_args
from scripts.helpers.common.pytest_runner import (
build_pytest_cmd,
count_collected_tests,
filter_collectable_node_ids,
)
from scripts.helpers.common.test_map_loader import load_baseline, prune_deleted_sources
# Marker expression passed to gate_new_tests and used for the wave that runs
# new/changed test nodes.
_CHANGED_TEST_MARKER = "not npu"
# Marker expression used for the wave that runs mapped regression tests and
# deleted-source guard tests.
_REGRESSION_MARKER = "not npu and not nightly and not network"
# Marker expression used for the single full-suite wave scheduled when a
# config path changed.
_FULL_SUITE_MARKER = "not npu"
# Coverage data file handed to the post-run source-mapping check in main().
_COVERAGE_DATA_PATH = REPO_ROOT / ".coverage"
# Default number of node ids _sample_nodes lists before summarizing the rest.
_SAMPLE_NODE_LIMIT = 3
# Human-readable scheduling reasons stored in ExecutionPlan.reasons and echoed
# in the log and success summary.
_REASON_CONFIG = "dependency or test configuration changed"
_REASON_CHANGED_TEST = "new or changed test file"
_REASON_REGRESSION = "changed product file mapped regression"
_REASON_DELETED_SOURCE = "deleted product file guard test"
def _remap_renamed_sources(
test_map: dict[str, dict[str, list[str]]],
renames: tuple[tuple[str, str, int], ...],
) -> dict[str, dict[str, list[str]]]:
"""Move coverage-mapping entries from old path to new path for renamed sources."""
remapped = dict(test_map)
for old_path, new_path, _score in renames:
if old_path in remapped:
remapped[new_path] = remapped.pop(old_path)
return remapped
def build_hard_blocking_plan(
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    roots: tuple[str, ...],
) -> tuple[GateError, ...]:
    """Collect the pre-run policy errors that must stop the gate before pytest.

    Two checks contribute, in this order: the deleted-source gate (only when
    the deleted sources include product paths under ``roots``) and the
    deleted-tests gate (only when any test file was deleted).

    Returns:
        The errors from both gates, deleted-source errors first; empty when
        neither gate applies or neither reports anything.
    """
    deleted_product = _product_paths(changes.del_source, roots)
    source_errors = gate_deleted_source(changes, test_map, roots).errors if deleted_product else ()
    test_errors = gate_deleted_tests(changes, test_map).errors if changes.del_test else ()
    return (*source_errors, *test_errors)
def build_coverage_mapping_errors(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path,
) -> tuple[GateError, ...]:
    """Collect post-run mapping errors for new and modified sources.

    Both gates are invoked with ``check_mapping=True`` and the given
    ``coverage_path``. The new-source gate runs only when the new sources
    include product paths under ``roots``; the modified-source gate runs
    whenever any source was modified.

    Returns:
        New-source errors followed by modified-source errors.
    """
    found: list[GateError] = []

    if _product_paths(changes.new_source, roots):
        # The new-source gate sees the map with deleted product files pruned;
        # the modified-source gate below always gets the unpruned map.
        if _product_paths(changes.del_source, roots):
            map_for_new = prune_deleted_sources(test_map, changes.del_source)
        else:
            map_for_new = test_map
        new_step = gate_new_source(
            repo_root,
            changes,
            map_for_new,
            exemptions,
            roots,
            coverage_path=coverage_path,
            check_mapping=True,
        )
        found += new_step.errors

    if changes.modified_source:
        modified_step = gate_modified_source(
            repo_root,
            changes,
            test_map,
            exemptions,
            roots,
            coverage_path=coverage_path,
            check_mapping=True,
        )
        found += modified_step.errors

    return tuple(found)
def build_ci_gate_plan(
    repo_root: Path,
    changes: ChangeSet,
    baseline: Baseline,
) -> CiGatePlan:
    """Select the tests to run for a change set, without any mapping validation.

    Three groups are computed:

    * guard tests from the deleted-source gate, when deleted sources include
      product paths;
    * changed-test nodes from the new-tests gate, skipped when a config change
      forces the full suite;
    * regression tests from the modified-source gate (``check_mapping=False``),
      minus anything already in the guard group.

    Returns:
        A ``CiGatePlan`` whose ``blocking_errors`` is always empty.
    """
    mapping = baseline.test_map
    source_roots = baseline.roots
    run_everything = bool(changes.config)

    guard_tests: frozenset[str] = frozenset()
    if _product_paths(changes.del_source, source_roots):
        guard_tests = gate_deleted_source(changes, mapping, source_roots).tests

    touched_tests: frozenset[str] = frozenset()
    if not run_everything:
        if changes.new_test or changes.modified_test:
            new_test_step = gate_new_tests(
                changes,
                baseline.test_exemptions,
                marker=_CHANGED_TEST_MARKER,
                full_suite=run_everything,
            )
            touched_tests = new_test_step.tests

    mapped_tests: frozenset[str] = frozenset()
    if changes.modified_source:
        modified_step = gate_modified_source(
            repo_root,
            changes,
            mapping,
            baseline.exemptions,
            source_roots,
            check_mapping=False,
        )
        # Guard tests are reported under their own group, so drop them here.
        mapped_tests = modified_step.tests - guard_tests

    return CiGatePlan(
        blocking_errors=(),
        deleted_source_tests=guard_tests,
        changed_test_nodes=touched_tests,
        regression_tests=mapped_tests,
        full_suite=run_everything,
    )
def _needs_union_coverage(changes: ChangeSet, roots: tuple[str, ...]) -> bool:
    """Return True when the change set touches product sources or tests.

    Product changes are: new or deleted sources with product paths under
    ``roots``, any modified source, or any rename. Test changes are any new,
    modified or deleted test file.
    """
    if _product_paths(changes.new_source, roots) or _product_paths(changes.del_source, roots):
        return True
    if changes.modified_source or changes.renames:
        return True
    return bool(changes.new_test or changes.modified_test or changes.del_test)
def _needs_post_run_mapping_check(changes: ChangeSet, roots: tuple[str, ...]) -> bool:
    """Return True when new product sources or any modified source are present."""
    if _product_paths(changes.new_source, roots):
        return True
    return bool(changes.modified_source)
def compute_execution_plan(plan: CiGatePlan, test_exemptions: tuple[TestExemption, ...]) -> ExecutionPlan:
    """Turn a gate plan into a deduplicated set of pytest waves.

    A full-suite plan yields one wave over ``tests`` with the full-suite
    marker. Otherwise each node id is scheduled once, keeping the first reason
    that claims it: changed test, then mapped regression, then deleted-source
    guard. Exempt nodes are dropped from the latter two groups only.

    Returns:
        An ``ExecutionPlan`` with up to two waves: changed-test nodes first,
        all remaining nodes second. Each wave's targets are sorted.
    """
    if plan.full_suite:
        everything = TestRunWave(targets=("tests",), marker=_FULL_SUITE_MARKER)
        return ExecutionPlan(
            full_suite=True,
            waves=(everything,),
            reasons={"tests/": _REASON_CONFIG},
        )

    reasons: dict[str, str] = dict.fromkeys(plan.changed_test_nodes, _REASON_CHANGED_TEST)
    lower_priority = (
        (plan.regression_tests, _REASON_REGRESSION),
        (plan.deleted_source_tests, _REASON_DELETED_SOURCE),
    )
    for candidates, why in lower_priority:
        for node_id in candidates:
            if not is_test_exempt(test_exemptions, node_id):
                # setdefault keeps whichever reason claimed the node first.
                reasons.setdefault(node_id, why)

    changed: list[str] = []
    others: list[str] = []
    for node_id in reasons:
        (changed if node_id in plan.changed_test_nodes else others).append(node_id)
    changed.sort()
    others.sort()

    wave_specs = (
        (tuple(changed), _CHANGED_TEST_MARKER),
        (tuple(others), _REGRESSION_MARKER),
    )
    waves = tuple(
        TestRunWave(targets=targets, marker=marker) for targets, marker in wave_specs if targets
    )
    return ExecutionPlan(full_suite=False, waves=waves, reasons=reasons)
def _collected_count_for_targets(targets: list[str], *, marker: str) -> int:
    """Return the count ``count_collected_tests`` reports for ``targets`` under ``marker``."""
    collected = count_collected_tests(targets, marker=marker)
    return collected
def _run_pytest(targets: list[str], *, marker: str, use_cov: bool = False) -> int:
if not targets:
return 0
logger = logging.getLogger("ci_gate")
if all("::" in target for target in targets):
collectable_set = frozenset(filter_collectable_node_ids(targets, marker=marker))
collectable = [target for target in targets if target in collectable_set]
skipped = [target for target in targets if target not in collectable_set]
if skipped:
logger.info("Skipping non-collectable pytest node(s): %s", ", ".join(skipped))
if not collectable:
return 0
run_targets = collectable
collected = len(collectable)
else:
run_targets = targets
collected = _collected_count_for_targets(targets, marker=marker)
extra_args = cov_pytest_args(cov_context=True) if use_cov else ()
cmd = build_pytest_cmd(
sys.executable,
run_targets,
marker=marker,
collected_count=collected,
extra_args=extra_args,
)
logger.info("Running pytest: %s", shlex.join(cmd))
return subprocess.run(cmd, cwd=REPO_ROOT, check=False).returncode
def _sample_nodes(nodes: tuple[str, ...], limit: int = _SAMPLE_NODE_LIMIT) -> str:
    """Return a comma-separated preview of ``nodes``.

    At most ``limit`` entries are listed; when more exist, a
    ``... (+N more)`` suffix reports how many were left out. An empty
    ``nodes`` gives an empty string.
    """
    if not nodes:
        return ""
    head = ", ".join(nodes[:limit])
    overflow = len(nodes) - limit
    return head if overflow <= 0 else f"{head}, ... (+{overflow} more)"
def _log_execution_plan(logger: logging.Logger, execution: ExecutionPlan) -> None:
if execution.full_suite:
logger.info("Selected full test suite: %s", _REASON_CONFIG)
return
if not execution.has_work:
logger.info("No pytest targets after policy checks; skipping test run")
return
counts = Counter(execution.reasons.values())
for reason, count in sorted(counts.items()):
logger.info("Scheduling %d test node(s): %s", count, reason)
all_nodes = tuple(node for wave in execution.waves for node in wave.targets)
logger.info("Sample node(s): %s", _sample_nodes(all_nodes))
logger.info("Execution uses %d pytest wave(s) after deduplication", len(execution.waves))
def _log_blocking_errors(logger: logging.Logger, errors: tuple[GateError, ...]) -> None:
counts: dict[str, int] = {}
for err in errors:
counts[err.category] = counts.get(err.category, 0) + 1
summary = ", ".join(f"{category}={count}" for category, count in sorted(counts.items()))
logger.error(
"Policy validation failed; pytest skipped (%d issue(s): %s)",
len(errors),
summary or "unknown",
)
def _log_pytest_failure(
logger: logging.Logger,
*,
full_suite: bool,
failed_nodes: tuple[str, ...],
) -> None:
if full_suite:
logger.error("Full test suite failed; see pytest output above")
print("CI gate failed: full test suite did not pass. See pytest output above.")
return
if failed_nodes:
logger.error("Selected tests failed; see pytest output above")
print(format_pytest_failure_hint(failed_nodes))
return
logger.error("Selected tests failed; see pytest output above")
print("CI gate failed: selected tests did not pass. See pytest output above.")
def _print_success_summary(execution: ExecutionPlan, changes: ChangeSet) -> None:
if execution.full_suite:
config_paths = ", ".join(changes.config) if changes.config else "tests/"
print(f"CI gate passed: full test suite ({config_paths})")
return
node_count = sum(len(wave.targets) for wave in execution.waves)
counts = Counter(execution.reasons.values())
reason_parts = ", ".join(f"{count} {reason}" for reason, count in sorted(counts.items()))
if reason_parts:
print(f"CI gate passed: {node_count} test node(s) ({reason_parts})")
else:
print(f"CI gate passed: {node_count} test node(s)")
def main() -> int:
    """Run the CI gate end to end and return a process exit code.

    Steps, in order: resolve the merge-base, validate gate policy and load the
    baseline, fetch and classify the diff, remap the test map for renames,
    apply the hard-blocking policy, build and log the execution plan, run the
    pytest waves, then run the post-run source-mapping check.

    Returns:
        ``0`` when the gate passes; ``1`` on a ``ConfigError`` or on any
        hard-blocking or post-run mapping error; otherwise the non-zero exit
        code of the first failing pytest wave.
    """
    logger = setup_logger()
    cfg = Config.from_env()
    log_env_audit(cfg, logger)
    logger.info("CI gate: classify diff, validate policy, plan tests, run deduplicated selection")
    logger.info("Resolving merge-base against %s ...", cfg.base_branch)
    try:
        merge_base = resolve_base_ref(REPO_ROOT, cfg.base_branch)
    except ConfigError as exc:
        logger.error("%s", exc)
        return 1
    # Only the first 12 characters are logged to keep the line short.
    logger.info("Merge-base: %s", merge_base[:12])
    try:
        validate_gate_policy_if_changed(REPO_ROOT, merge_base)
        baseline = load_baseline(REPO_ROOT, cfg)
    except ConfigError as exc:
        logger.error("%s", exc)
        return 1
    logger.info("Fetching diff ...")
    # NOTE(review): fetch_diff and classify_changes are called outside any
    # try/except; confirm whether they can raise ConfigError as well.
    diff_result = fetch_diff(REPO_ROOT, merge_base)
    logger.info("Diff: %d files changed", len(diff_result.line_map))
    logger.info("Classifying changes ...")
    changes = classify_changes(REPO_ROOT, merge_base, diff_result, baseline.discovery, baseline.roots)
    if changes.config:
        logger.info("Config path(s): %s", ", ".join(changes.config))
    logger.info(
        "Changes: config=%d new_test=%d mod_test=%d del_test=%d new_source=%d del_source=%d modified=%d renames=%d",
        len(changes.config),
        len(changes.new_test),
        len(changes.modified_test),
        len(changes.del_test),
        len(changes.new_source),
        len(changes.del_source),
        len(changes.modified_source),
        len(changes.renames),
    )
    if changes.renames:
        logger.info("Remapping coverage mapping for %d renamed source(s) ...", len(changes.renames))
        # Rebuild the baseline with a re-keyed test map; every other field is
        # carried over unchanged.
        baseline = baseline.__class__(
            test_map=_remap_renamed_sources(baseline.test_map, changes.renames),
            exemptions=baseline.exemptions,
            test_exemptions=baseline.test_exemptions,
            discovery=baseline.discovery,
            roots=baseline.roots,
        )
    logger.info("Validating hard-blocking policy ...")
    hard_errors = build_hard_blocking_plan(changes, baseline.test_map, baseline.roots)
    if hard_errors:
        # Hard errors stop the gate before any pytest run.
        _log_blocking_errors(logger, hard_errors)
        print(format_blocking_errors(hard_errors))
        return 1
    logger.info("Building gate plan ...")
    plan = build_ci_gate_plan(REPO_ROOT, changes, baseline)
    execution = compute_execution_plan(plan, baseline.test_exemptions)
    _log_execution_plan(logger, execution)
    use_cov = _needs_union_coverage(changes, baseline.roots)
    if use_cov and execution.has_work:
        logger.info("Union pytest run will collect branch coverage with per-test context")
    # Nothing to run and nothing to verify afterwards: pass immediately.
    if not execution.has_work and not _needs_post_run_mapping_check(changes, baseline.roots):
        print("CI gate passed")
        return 0
    pytest_code = 0
    if execution.has_work:
        # NOTE(review): each wave is a separate pytest process. Whether coverage
        # data from earlier waves survives later ones depends on what
        # cov_pytest_args emits — confirm.
        for wave in execution.waves:
            pytest_code = _run_pytest(list(wave.targets), marker=wave.marker, use_cov=use_cov)
            if pytest_code != 0:
                # Stop at the first failing wave. The hint lists every target
                # of that wave, not only the tests that actually failed.
                failed_nodes = wave.targets if not execution.full_suite else ()
                _log_pytest_failure(logger, full_suite=execution.full_suite, failed_nodes=failed_nodes)
                return pytest_code
    if _needs_post_run_mapping_check(changes, baseline.roots):
        # NOTE(review): this check also runs when no pytest wave executed;
        # confirm how the gates treat a missing or stale coverage file.
        logger.info("Checking new/modified source coverage mapping against collected data ...")
        soft_errors = build_coverage_mapping_errors(
            REPO_ROOT,
            changes,
            baseline.test_map,
            baseline.exemptions,
            baseline.roots,
            coverage_path=_COVERAGE_DATA_PATH,
        )
        if soft_errors:
            _log_blocking_errors(logger, soft_errors)
            print(format_blocking_errors(soft_errors, pytest_ran=True))
            return 1
    if execution.has_work:
        _print_success_summary(execution, changes)
    else:
        print("CI gate passed")
    return 0


# Script entry point: exit with the code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())