"""Gate rules: individual policy checks that produce GateStepResult.
Each gate_* function is a self-contained rule. No orchestration here —
build_ci_gate_plan in main.py decides which rules to run and merges results.
"""
from __future__ import annotations
import logging
from pathlib import Path
from scripts.helpers.ci_gate.gate_policy import SourceExemption, TestExemption, is_exempt, is_test_exempt
from scripts.helpers.ci_gate.models import ChangeSet, GateError, GateStepResult
from scripts.helpers.common import ast_utils
from scripts.helpers.common.coverage_omit import is_coverage_omitted_source
from scripts.helpers.common.coverage_symbol_check import symbol_lines_covered_in_data
from scripts.helpers.common.pytest_runner import collect_test_node_ids
from scripts.helpers.common.test_map_loader import is_product_source
# Module-level logger; used below to record accepted coverage fallbacks.
logger = logging.getLogger(__name__)
def _merge_step_results(*steps: GateStepResult) -> GateStepResult:
    """Fold any number of step results into a single GateStepResult.

    Errors are concatenated in argument order; test ids are unioned, so
    duplicates across steps collapse. With no arguments the result carries an
    empty error tuple and an empty test set.
    """
    merged_errors = tuple(error for step in steps for error in step.errors)
    merged_tests = frozenset(test_id for step in steps for test_id in step.tests)
    return GateStepResult(errors=merged_errors, tests=merged_tests)
def _product_paths(paths: tuple[str, ...], prefixes: tuple[str, ...]) -> tuple[str, ...]:
    """Return the entries of *paths* accepted by is_product_source for *prefixes*.

    The relative order of the input is preserved.
    """
    kept: list[str] = []
    for candidate in paths:
        if is_product_source(candidate, prefixes):
            kept.append(candidate)
    return tuple(kept)
def _executable_lines_for_symbol(source_path: Path, symbol: str) -> set[int]:
    """Return the executable lines inside the definition span of *symbol*.

    Only the first span whose qualified name equals *symbol* is considered.
    Its inclusive start..end line range is narrowed by
    ast_utils.filter_executable_lines. An empty set is returned when no span
    matches or when the matching span covers no lines.
    """
    for span in ast_utils.iter_qualified_definition_spans(source_path):
        if span.qualified_name != symbol:
            continue
        # First match wins; later spans with the same name are ignored.
        span_lines = set(range(span.start_line, span.end_line + 1))
        if not span_lines:
            return set()
        return ast_utils.filter_executable_lines(source_path, span_lines)
    return set()
def _symbol_lines_from_diff(source_path: Path, symbol: str, executable: set[int]) -> set[int]:
    """Return the subset of *executable* lines that belong to *symbol*.

    A line is kept when ast_utils.symbol_for_line resolves it to exactly
    *symbol* against the definition spans of *source_path*.
    """
    if not executable:
        # Nothing to classify, so skip reading the spans altogether.
        return set()
    # The spans are consulted once per line. Materialize them so the lookups
    # keep working if the helper yields lazily.
    # NOTE(review): whether iter_qualified_definition_spans is a one-shot
    # iterator is not visible from this module; the copy is harmless if not.
    spans = list(ast_utils.iter_qualified_definition_spans(source_path))
    return {line_no for line_no in executable if ast_utils.symbol_for_line(spans, line_no) == symbol}
def _coverage_fallback_passes(
    repo_root: Path,
    file_path: str,
    symbol: str,
    lines: set[int],
    *,
    coverage_path: Path | None,
) -> bool:
    """Report whether coverage data can stand in for a missing test mapping.

    Returns False straight away when no coverage file is given or *lines* is
    empty. Otherwise the decision is delegated to
    symbol_lines_covered_in_data, and an accepted fallback is logged at INFO
    level.
    """
    if not (coverage_path and lines):
        return False
    accepted = symbol_lines_covered_in_data(repo_root, file_path, symbol, lines, coverage_path)
    if not accepted:
        return False
    logger.info(
        "Coverage fallback accepted %s::%s (%d executable line(s))",
        file_path,
        symbol,
        len(lines),
    )
    return True
def gate_new_source(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None = None,
    check_mapping: bool = True,
) -> GateStepResult:
    """Report newly added product source that has no test mapping.

    A path in changes.new_source is skipped when it is not a ``.py`` file, is
    not product source under *roots*, is coverage-omitted, already has a
    non-empty entry in *test_map*, or does not exist under *repo_root*.

    For the remaining files:
    - a file with no top-level definitions yields one file-level error;
    - otherwise each top-level symbol yields an error unless it is exempt or
      the coverage fallback accepts its executable lines.

    Returns an empty result when *check_mapping* is False. This rule never
    schedules tests; only ``errors`` is populated.
    """
    if not check_mapping:
        return GateStepResult()

    errors: list[GateError] = []
    for path in changes.new_source:
        if not path.endswith(".py"):
            continue
        if not is_product_source(path, roots) or is_coverage_omitted_source(path, roots):
            continue
        if test_map.get(path):
            # Any existing mapping for the file is enough to satisfy this rule.
            continue
        source_path = repo_root / path
        if not source_path.is_file():
            continue

        symbols = ast_utils.top_level_definitions(source_path)
        if not symbols:
            errors.append(GateError(category="new_source", path=path))
            continue

        for sym in symbols:
            if is_exempt(exemptions, path, sym):
                continue
            # Without coverage data the fallback can never pass, so skip
            # computing the symbol's executable lines in that case.
            if coverage_path and _coverage_fallback_passes(
                repo_root,
                path,
                sym,
                _executable_lines_for_symbol(source_path, sym),
                coverage_path=coverage_path,
            ):
                continue
            errors.append(GateError(category="new_source", path=path, symbol=sym))
    return GateStepResult(errors=tuple(errors))
def gate_new_tests(
    changes: ChangeSet,
    test_exemptions: tuple[TestExemption, ...],
    *,
    marker: str,
    full_suite: bool = False,
) -> GateStepResult:
    """Collect runnable pytest node ids from new or modified test files.

    Node ids are collected once for all new and modified test files (filtered
    by *marker*). A node is scheduled when its file part, the text before the
    first ``::``, is one of those files and the node is not test-exempt.
    Files whose collected nodes are all exempt therefore contribute nothing.

    Returns an empty result when *full_suite* is True or when there are no
    new or modified test files.
    """
    if full_suite:
        return GateStepResult()
    test_files = changes.new_test + changes.modified_test
    if not test_files:
        return GateStepResult()

    collected = collect_test_node_ids(list(test_files), marker=marker)
    # One pass over the collected ids with a set lookup, instead of rescanning
    # the whole list for every test file.
    wanted_files = set(test_files)
    runnable = frozenset(
        node_id
        for node_id in collected
        if node_id.split("::", 1)[0] in wanted_files and not is_test_exempt(test_exemptions, node_id)
    )
    return GateStepResult(tests=runnable)
def gate_deleted_source(
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    roots: tuple[str, ...],
) -> GateStepResult:
    """Handle deleted product source files.

    A deleted product file with no (or an empty) entry in *test_map* yields a
    ``deleted_source`` error. Otherwise every test id mapped to the file is
    scheduled, except ids whose file part is itself listed in
    changes.del_test. Non-product paths are ignored.
    """
    problems: list[GateError] = []
    guard_tests: set[str] = set()
    removed_test_files = set(changes.del_test)

    for path in changes.del_source:
        if not is_product_source(path, roots):
            continue
        mapping = test_map.get(path)
        if not mapping:
            problems.append(GateError(category="deleted_source", path=path))
            continue
        for test_ids in mapping.values():
            for test_id in test_ids:
                # Tests living in a file deleted by this change cannot be run.
                if test_id.partition("::")[0] not in removed_test_files:
                    guard_tests.add(test_id)

    return GateStepResult(errors=tuple(problems), tests=frozenset(guard_tests))
def gate_deleted_tests(
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
) -> GateStepResult:
    """Report deleted test files that leave a mapped source symbol untested.

    For every ``source::symbol`` in *test_map* whose source file is not itself
    deleted, the set of test files covering it is taken from the file part of
    each mapped test id. A symbol is orphaned when that set is non-empty and
    every file in it is deleted by this change. This includes the case where
    coverage was spread over several test files that are all deleted together.

    One ``deleted_test`` error is emitted per deleted test path that
    contributed to at least one orphaned symbol; its ``detail`` lists the
    affected ``source::symbol`` entries, one per line.
    """
    deleted_sources = set(changes.del_source)
    deleted_test_files = {path.split("::", 1)[0] for path in changes.del_test}

    # Build the per-symbol covering-file sets once rather than once per
    # deleted test path.
    orphaned: list[tuple[str, set[str]]] = []
    for src_file, symbols in test_map.items():
        if src_file in deleted_sources:
            continue
        for symbol, test_ids in symbols.items():
            covering_files = {test_id.split("::", 1)[0] for test_id in test_ids}
            # A symbol with no mapped tests is not made worse by this change.
            if covering_files and covering_files <= deleted_test_files:
                orphaned.append((f"{src_file}::{symbol}", covering_files))

    errors: list[GateError] = []
    for deleted_path in changes.del_test:
        deleted_file = deleted_path.split("::", 1)[0]
        lost = [label for label, covering_files in orphaned if deleted_file in covering_files]
        if lost:
            errors.append(GateError(category="deleted_test", path=deleted_path, detail="\n".join(lost)))
    return GateStepResult(errors=tuple(errors))
def gate_modified_source(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None = None,
    check_mapping: bool = True,
) -> GateStepResult:
    """Schedule tests for modified product source and flag unmapped symbols.

    changes.modified_source is iterated as ``(path, raw_lines)`` pairs. Paths
    that are not product source under *roots*, or that are coverage-omitted,
    are skipped, as are files whose changed lines contain nothing executable.

    For each symbol touched by an executable changed line:
    - if *test_map* maps it to tests, those tests are scheduled;
    - otherwise it is an error unless *check_mapping* is False, the symbol is
      exempt, or the coverage fallback accepts the changed lines.

    Unlike gate_new_source, this rule still schedules mapped tests when
    *check_mapping* is False; only the error reporting is suppressed.
    """
    errors: list[GateError] = []
    tests: set[str] = set()

    for path, raw_lines in changes.modified_source:
        if not is_product_source(path, roots) or is_coverage_omitted_source(path, roots):
            continue
        source_path = repo_root / path
        executable = ast_utils.filter_executable_lines(source_path, set(raw_lines))
        if not executable:
            continue

        file_map = test_map.get(path, {})
        for symbol in ast_utils.symbols_for_lines(source_path, executable):
            mapped = file_map.get(symbol)
            if mapped:
                tests.update(mapped)
                continue
            if not check_mapping or is_exempt(exemptions, path, symbol):
                continue
            # Without coverage data the fallback can never pass, so skip
            # working out which changed lines belong to the symbol.
            if coverage_path and _coverage_fallback_passes(
                repo_root,
                path,
                symbol,
                _symbol_lines_from_diff(source_path, symbol, executable),
                coverage_path=coverage_path,
            ):
                continue
            errors.append(GateError(category="modified_source", path=path, symbol=symbol))

    return GateStepResult(errors=tuple(errors), tests=frozenset(tests))