"""Gate rules: individual policy checks that produce GateStepResult."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from scripts.helpers.ci_gate.models import (
ChangeSet,
GateError,
GateStepResult,
SourceExemption,
TestExemption,
)
from scripts.helpers.ci_gate.policy import is_exempt, is_test_exempt
from scripts.helpers.ci_gate.test_map_query import (
TestMapIndex,
build_test_map_index,
is_symbol_mapped,
nodes_for_test_file,
source_watchers,
symbol_watchers,
)
from scripts.helpers.common import ast_utils
from scripts.helpers.common.ast_utils import (
MODULE_SYMBOL,
CoverageChecks,
coverage_checks_for_definition,
executable_lines_for_canonical_symbol,
gated_coverage_symbols,
import_symbol_for_definition,
touched_definition_symbols,
)
from scripts.helpers.common.coverage_omit import is_coverage_omitted_source
from scripts.helpers.common.coverage_symbol_check import (
CoverageDataReader,
symbol_lines_covered_in_data,
)
from scripts.helpers.common.pytest_runner import collect_all_test_node_ids
from scripts.helpers.common.test_map_loader import is_product_source
if TYPE_CHECKING:
from pathlib import Path
logger = logging.getLogger(__name__)
def _product_paths(paths: tuple[str, ...], prefixes: tuple[str, ...]) -> tuple[str, ...]:
    """Return the entries of ``paths`` that ``is_product_source`` accepts for ``prefixes``.

    The relative order of the input paths is preserved.
    """
    kept: list[str] = []
    for candidate in paths:
        if is_product_source(candidate, prefixes):
            kept.append(candidate)
    return tuple(kept)
def _symbol_lines_from_diff(source_path: Path, canonical_symbol: str, executable: set[int]) -> set[int]:
    """Return the lines of ``executable`` that also belong to ``canonical_symbol``.

    The symbol's own executable lines are looked up in ``source_path`` and
    intersected with the ``executable`` set supplied by the caller.
    """
    symbol_lines = executable_lines_for_canonical_symbol(source_path, canonical_symbol)
    return executable & symbol_lines
def _is_import_scope_symbol(symbol: str) -> bool:
    """Tell whether ``symbol`` is ``MODULE_SYMBOL`` itself or ends with ``::MODULE_SYMBOL``."""
    if symbol == MODULE_SYMBOL:
        return True
    scoped_suffix = f"::{MODULE_SYMBOL}"
    return symbol.endswith(scoped_suffix)
def _skip_import_scope_via_decorator_coverage(
    symbol: str,
    executable: set[int],
    sym_lines: set[int],
    decorator_import_lines: set[int],
) -> bool:
    """Decide whether an import-scope symbol needs no check of its own.

    Returns ``True`` only when ``symbol`` is an import-scope symbol,
    ``decorator_import_lines`` is non-empty, and either every line in
    ``executable`` or every line in a non-empty ``sym_lines`` is contained in
    ``decorator_import_lines``.
    """
    if not (_is_import_scope_symbol(symbol) and decorator_import_lines):
        return False
    if executable <= decorator_import_lines:
        return True
    # An empty ``sym_lines`` is trivially a subset, so it must not count.
    return bool(sym_lines) and sym_lines <= decorator_import_lines
def _coverage_checks_pass(
    repo_root: Path,
    file_path: str,
    source_path: Path,
    symbol: str,
    checks: CoverageChecks,
    *,
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None = None,
) -> bool:
    """Return ``True`` when every applicable coverage check for ``symbol`` passes.

    Import lines, when present, are checked against the symbol returned by
    ``import_symbol_for_definition``. After that, strict lines are checked
    against ``symbol`` itself; proxy lines are consulted only when there are no
    strict lines. With nothing to check the result is ``True``.
    """
    pending: list[tuple[str, set[int]]] = []
    if checks.import_lines:
        import_symbol = import_symbol_for_definition(source_path, symbol)
        pending.append((import_symbol, set(checks.import_lines)))
    # Strict lines take precedence; proxy lines are only a substitute.
    body_lines = checks.strict_lines or checks.proxy_lines
    if body_lines:
        pending.append((symbol, set(body_lines)))
    # ``all`` stops at the first failing check, so later checks are not run.
    return all(
        _coverage_fallback_passes(
            repo_root,
            file_path,
            target_symbol,
            line_numbers,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
        )
        for target_symbol, line_numbers in pending
    )
def _coverage_fallback_passes(
    repo_root: Path,
    file_path: str,
    symbol: str,
    lines: set[int],
    *,
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None = None,
) -> bool:
    """Return ``True`` when coverage data shows ``lines`` of ``symbol`` as covered.

    The answer is ``False`` when ``lines`` is empty or when neither
    ``coverage_data`` nor ``coverage_path`` is available. A test context is
    required for every symbol except import-scope (module-level) ones. An
    accepted fallback is logged at INFO level.
    """
    if not lines:
        return False
    if coverage_data is None and not coverage_path:
        return False
    covered = symbol_lines_covered_in_data(
        repo_root,
        file_path,
        symbol,
        lines,
        coverage_path,
        coverage_data=coverage_data,
        require_test_context=not _is_import_scope_symbol(symbol),
    )
    if not covered:
        return False
    logger.info(
        "Coverage fallback accepted %s::%s (%d executable line(s))",
        file_path,
        symbol,
        len(lines),
    )
    return True
def _modified_source_mapping_error(
    repo_root: Path,
    path: str,
    source_path: Path,
    symbol: str,
    executable: set[int],
    *,
    touched: frozenset[str],
    touched_checks: dict[str, CoverageChecks],
    decorator_import_lines: set[int],
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None,
) -> GateError | None:
    """Return a ``modified_source`` error for ``symbol``, or ``None`` when it is satisfied.

    A symbol in ``touched`` is judged by its precomputed ``CoverageChecks``:
    it is satisfied when the checks are all empty or when they pass. Any other
    symbol is judged by its changed executable lines: it is satisfied when the
    decorator-coverage skip applies, when it has no changed lines, or when the
    coverage fallback accepts those lines.
    """
    if symbol in touched:
        checks = touched_checks[symbol]
        nothing_to_check = not (checks.import_lines or checks.strict_lines or checks.proxy_lines)
        satisfied = nothing_to_check or _coverage_checks_pass(
            repo_root,
            path,
            source_path,
            symbol,
            checks,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
        )
    else:
        sym_lines = _symbol_lines_from_diff(source_path, symbol, executable)
        satisfied = (
            _skip_import_scope_via_decorator_coverage(symbol, executable, sym_lines, decorator_import_lines)
            or not sym_lines
            or _coverage_fallback_passes(
                repo_root,
                path,
                symbol,
                sym_lines,
                coverage_path=coverage_path,
                coverage_data=coverage_data,
            )
        )
    if satisfied:
        return None
    return GateError(category="modified_source", path=path, symbol=symbol)
def _symbol_error_for_unmapped_source(
    repo_root: Path,
    path: str,
    source_path: Path,
    canonical_symbol: str,
    *,
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None = None,
) -> list[GateError]:
    """Return zero or one ``new_source`` errors for an unmapped symbol.

    No error is produced when the symbol has no executable lines or when the
    coverage fallback accepts them. For the module-level symbol the reported
    error carries ``symbol=None``.
    """
    symbol_lines = executable_lines_for_canonical_symbol(source_path, canonical_symbol)
    needs_error = bool(symbol_lines) and not _coverage_fallback_passes(
        repo_root,
        path,
        canonical_symbol,
        symbol_lines,
        coverage_path=coverage_path,
        coverage_data=coverage_data,
    )
    if not needs_error:
        return []
    reported_symbol = canonical_symbol if canonical_symbol != MODULE_SYMBOL else None
    return [GateError(category="new_source", path=path, symbol=reported_symbol)]
def _new_source_errors_for_path(
    repo_root: Path,
    path: str,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None = None,
) -> list[GateError]:
    """Collect ``new_source`` errors for one newly added file.

    The file is skipped (empty result) when it is not a ``.py`` file, is
    coverage-omitted, does not exist under ``repo_root``, or has no gated
    symbols. Remaining symbols are visited in sorted order; exempt or
    already-mapped symbols are skipped, the others go through the coverage
    fallback.
    """
    if not path.endswith(".py") or is_coverage_omitted_source(path, roots):
        return []
    source_path = repo_root / path
    if not source_path.is_file():
        return []
    required_symbols = gated_coverage_symbols(source_path)
    if not required_symbols:
        return []
    found: list[GateError] = []
    for canonical_symbol in sorted(required_symbols):
        already_handled = is_exempt(exemptions, path, canonical_symbol) or is_symbol_mapped(
            test_map, path, canonical_symbol
        )
        if already_handled:
            continue
        found += _symbol_error_for_unmapped_source(
            repo_root,
            path,
            source_path,
            canonical_symbol,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
        )
    return found
def gate_new_source(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None = None,
    coverage_data: CoverageDataReader | None = None,
    check_mapping: bool = True,
) -> GateStepResult:
    """Gate newly added source files.

    Returns an empty result when ``check_mapping`` is false. Otherwise the
    errors of every path in ``changes.new_source`` are concatenated in order.
    """
    if not check_mapping:
        return GateStepResult()
    collected_errors = tuple(
        error
        for new_path in changes.new_source
        for error in _new_source_errors_for_path(
            repo_root,
            new_path,
            test_map,
            exemptions,
            roots,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
        )
    )
    return GateStepResult(errors=collected_errors)
def gate_new_tests(
    changes: ChangeSet,
    test_exemptions: tuple[TestExemption, ...],
    *,
    full_suite: bool = False,
) -> GateStepResult:
    """Collect pytest node ids from new or modified test files without marker filtering.

    Args:
        changes: Change set; ``new_test`` and ``modified_test`` are inspected.
        test_exemptions: Exemptions used to drop individual node ids.
        full_suite: When true nothing is collected and an empty result is returned.

    Returns:
        A result whose ``tests`` holds the non-exempt node ids and whose
        ``all_exempt_test_files`` lists files that have collected nodes, all
        of which are exempt.
    """
    if full_suite:
        return GateStepResult()
    test_files = changes.new_test + changes.modified_test
    if not test_files:
        return GateStepResult()
    collected = collect_all_test_node_ids(list(test_files))

    # Group node ids by file in one pass instead of rescanning (and
    # re-splitting) the whole collection once per test file.
    nodes_by_file: dict[str, list[str]] = {}
    for node_id in collected:
        nodes_by_file.setdefault(node_id.split("::", 1)[0], []).append(node_id)

    nodes: list[str] = []
    all_exempt: list[str] = []
    for test_file in test_files:
        file_nodes = nodes_by_file.get(test_file, [])
        runnable = [node_id for node_id in file_nodes if not is_test_exempt(test_exemptions, node_id)]
        if file_nodes and not runnable:
            # Every collected node of this file is exempt.
            all_exempt.append(test_file)
            continue
        nodes.extend(runnable)
    return GateStepResult(tests=frozenset(nodes), all_exempt_test_files=tuple(all_exempt))
def gate_deleted_source(
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    roots: tuple[str, ...],
) -> GateStepResult:
    """Gate deleted product source files against the test map.

    For every deleted product source, each watcher node whose test file is not
    itself deleted is added to the returned ``tests``. For each symbol such a
    node maps in that source, the symbol is reported when no other watcher from
    a non-deleted test file exists. One ``deleted_source`` error is emitted per
    path, with the sorted, de-duplicated ``path::symbol`` entries as detail.
    """
    deleted_test_files = set(changes.del_test)
    index = build_test_map_index(test_map)

    def _survives(node_id: str) -> bool:
        # A node survives when the test file it lives in is not being deleted.
        return node_id.split("::", 1)[0] not in deleted_test_files

    tests: set[str] = set()
    errors: list[GateError] = []
    for path in changes.del_source:
        if not is_product_source(path, roots):
            continue
        orphaned: set[str] = set()
        for node in source_watchers(test_map, path, index=index):
            if not _survives(node):
                continue
            tests.add(node)
            for symbol in test_map.get(node, {}).get(path, []):
                rivals = symbol_watchers(test_map, path, symbol, index=index)
                if not any(other != node and _survives(other) for other in rivals):
                    orphaned.add(f"{path}::{symbol}")
        if orphaned:
            errors.append(GateError(category="deleted_source", path=path, detail="\n".join(sorted(orphaned))))
    return GateStepResult(errors=tuple(errors), tests=frozenset(tests))
def gate_deleted_tests(
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
) -> GateStepResult:
    """Report symbols that lose their last watcher when test files are deleted.

    For every deleted test file, each of its nodes in ``test_map`` is examined.
    Sources that are themselves deleted are ignored. A symbol is reported when
    no other watcher remains in a test file that is not being deleted.

    Watchers living in deleted test files do not count as remaining coverage.
    Without that filter, two nodes in the same deleted file (or in two deleted
    files) would mask each other and the loss would go unreported.
    ``gate_deleted_source`` applies the same filter.

    Args:
        changes: Change set; ``del_test`` and ``del_source`` are inspected.
        test_map: Mapping of test node id -> source path -> watched symbols.

    Returns:
        A result with one ``deleted_test`` error per deleted test file that
        leaves symbols unwatched; ``detail`` holds the sorted, de-duplicated
        ``source::symbol`` entries, one per line.
    """
    errors: list[GateError] = []
    deleted_source_set = set(changes.del_source)
    deleted_test_files = set(changes.del_test)
    index = build_test_map_index(test_map)

    def _survives(node_id: str) -> bool:
        # A node survives when the test file it lives in is not being deleted.
        return node_id.split("::", 1)[0] not in deleted_test_files

    for deleted_path in changes.del_test:
        sole_coverage: set[str] = set()
        for node in nodes_for_test_file(test_map, deleted_path):
            for src_file, symbols in test_map.get(node, {}).items():
                if src_file in deleted_source_set:
                    continue
                for symbol in symbols:
                    watchers = symbol_watchers(test_map, src_file, symbol, index=index)
                    if not any(other != node and _survives(other) for other in watchers):
                        sole_coverage.add(f"{src_file}::{symbol}")
        if sole_coverage:
            detail = "\n".join(sorted(sole_coverage))
            errors.append(GateError(category="deleted_test", path=deleted_path, detail=detail))
    return GateStepResult(errors=tuple(errors))
def _modified_source_path_gate(
    repo_root: Path,
    path: str,
    raw_lines: frozenset[int],
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    index: TestMapIndex,
    *,
    coverage_path: Path | None,
    coverage_data: CoverageDataReader | None,
    check_mapping: bool,
    collect_tests: bool,
) -> GateStepResult:
    """Gate one modified source file.

    The file is skipped (empty result) when it is not product source, is
    coverage-omitted, or none of ``raw_lines`` is executable. Otherwise every
    symbol owning a changed executable line, plus every touched definition, is
    examined:

    * a symbol with watchers contributes those watchers to ``tests`` when
      ``collect_tests`` is true (an import-scope symbol without direct
      watchers falls back to the watchers of the whole file);
    * an unwatched symbol is ignored when ``check_mapping`` is false or the
      symbol is exempt; otherwise it yields a ``modified_source`` error unless
      the coverage checks accept it.

    Symbols are visited in sorted order so the returned errors are ordered
    deterministically from run to run.
    """
    if not is_product_source(path, roots):
        return GateStepResult()
    if is_coverage_omitted_source(path, roots):
        return GateStepResult()
    source_path = repo_root / path
    executable = ast_utils.filter_executable_lines(source_path, set(raw_lines))
    if not executable:
        return GateStepResult()
    line_symbols = ast_utils.canonical_symbols_for_lines(source_path, executable)
    touched = touched_definition_symbols(source_path, executable)
    all_symbols = line_symbols | touched
    touched_checks: dict[str, CoverageChecks] = {
        touched_symbol: coverage_checks_for_definition(source_path, touched_symbol, executable)
        for touched_symbol in touched
    }
    # Union of the import lines of every touched definition.
    decorator_import_lines = set().union(*(checks.import_lines for checks in touched_checks.values()))
    errors: list[GateError] = []
    tests: set[str] = set()
    # Set iteration order of strings depends on hash randomization; sort so
    # the error tuple is stable across runs.
    for symbol in sorted(all_symbols):
        mapped = symbol_watchers(test_map, path, symbol, index=index)
        if not mapped and _is_import_scope_symbol(symbol):
            mapped = source_watchers(test_map, path, index=index)
        if mapped:
            if collect_tests:
                tests.update(mapped)
            continue
        if not check_mapping or is_exempt(exemptions, path, symbol):
            continue
        mapping_error = _modified_source_mapping_error(
            repo_root,
            path,
            source_path,
            symbol,
            executable,
            touched=touched,
            touched_checks=touched_checks,
            decorator_import_lines=decorator_import_lines,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
        )
        if mapping_error is not None:
            errors.append(mapping_error)
    return GateStepResult(errors=tuple(errors), tests=frozenset(tests))
def _run_modified_source_gate(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None = None,
    coverage_data: CoverageDataReader | None = None,
    check_mapping: bool,
    collect_tests: bool,
) -> GateStepResult:
    """Run the per-path modified-source gate over ``changes.modified_source`` and merge results.

    The test-map index is built once and shared by every path. Errors are
    concatenated in iteration order; tests are unioned.
    """
    index = build_test_map_index(test_map)
    per_path_results = [
        _modified_source_path_gate(
            repo_root,
            path,
            raw_lines,
            test_map,
            exemptions,
            roots,
            index,
            coverage_path=coverage_path,
            coverage_data=coverage_data,
            check_mapping=check_mapping,
            collect_tests=collect_tests,
        )
        for path, raw_lines in changes.modified_source
    ]
    merged_errors: list[GateError] = []
    merged_tests: set[str] = set()
    for step in per_path_results:
        merged_errors += step.errors
        merged_tests |= set(step.tests)
    return GateStepResult(errors=tuple(merged_errors), tests=frozenset(merged_tests))
def gate_modified_source(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path | None = None,
    coverage_data: CoverageDataReader | None = None,
    check_mapping: bool = True,
) -> GateStepResult:
    """Gate modified source files, collecting the tests that watch them.

    Delegates to ``_run_modified_source_gate`` with test collection enabled;
    mapping errors are produced only when ``check_mapping`` is true.
    """
    result = _run_modified_source_gate(
        repo_root=repo_root,
        changes=changes,
        test_map=test_map,
        exemptions=exemptions,
        roots=roots,
        coverage_path=coverage_path,
        coverage_data=coverage_data,
        check_mapping=check_mapping,
        collect_tests=True,
    )
    return result
def collect_modified_source_mapping_errors(
    repo_root: Path,
    changes: ChangeSet,
    test_map: dict[str, dict[str, list[str]]],
    exemptions: tuple[SourceExemption, ...],
    roots: tuple[str, ...],
    *,
    coverage_path: Path,
    coverage_data: CoverageDataReader | None,
) -> tuple[GateError, ...]:
    """Return only the mapping errors of the modified-source gate.

    Delegates to ``_run_modified_source_gate`` with mapping checks enabled and
    test collection disabled, then returns the ``errors`` of that result.
    """
    step = _run_modified_source_gate(
        repo_root=repo_root,
        changes=changes,
        test_map=test_map,
        exemptions=exemptions,
        roots=roots,
        coverage_path=coverage_path,
        coverage_data=coverage_data,
        check_mapping=True,
        collect_tests=False,
    )
    return step.errors