"""AST helpers for mapping tests to source symbols and filtering executable lines."""
from __future__ import annotations
import ast
import json
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Final
# Assignment targets whose statements are treated as metadata, not executable code.
_NON_EXECUTABLE_ASSIGN_NAMES: Final = frozenset(("__all__", "__version__"))
# Symbol id for module-level code; class-level code uses ``<Class>::%``.
MODULE_SYMBOL: Final = "%"
@dataclass(frozen=True, slots=True)
class DefinitionSpan:
qualified_name: str
start_line: int
end_line: int
@dataclass(frozen=True, slots=True)
class FileSymbols:
definitions: tuple[DefinitionSpan, ...]
class_spans: tuple[DefinitionSpan, ...]
decorator_lines: frozenset[int]
@dataclass(frozen=True, slots=True)
class CoverageChecks:
import_lines: frozenset[int]
strict_lines: frozenset[int]
proxy_lines: frozenset[int]
@dataclass(frozen=True, slots=True)
class ShadowWarning:
file: str
line: int
name: str
shadowed_by_line: int
@lru_cache(maxsize=128)
def _parse_cached(path_str: str, mtime_ns: int) -> ast.Module:
    """Parse the Python file at *path_str* into an AST (memoized).

    *mtime_ns* is not read by the body; it only participates in the
    ``lru_cache`` key so that an edited file (new mtime) is parsed again
    instead of being served from the cache.

    The file is read as bytes and handed to the parser unchanged, so PEP 263
    coding declarations are honoured and a UTF-8 BOM is accepted. Undecodable
    input is reported by the parser as ``SyntaxError`` rather than escaping as
    ``UnicodeDecodeError``.

    Raises:
        OSError: the file cannot be read.
        SyntaxError: the file is not valid Python, cannot be decoded, or
            contains NUL bytes.
    """
    source = Path(path_str).read_bytes()
    try:
        return ast.parse(source, filename=path_str)
    except ValueError as exc:
        # Python < 3.12 reports NUL bytes (and possibly decode problems) as
        # ValueError; normalize so "unparseable file" is always SyntaxError.
        raise SyntaxError(f"{path_str}: {exc}") from exc
def _get_cached_tree(path: Path) -> ast.Module:
    """Return the memoized AST for *path*.

    The cache key pairs the path string with the file's current
    ``st_mtime_ns``, so a modified file is parsed afresh. ``OSError`` from
    ``stat`` and parse errors propagate to the caller.
    """
    modified_ns = path.stat().st_mtime_ns
    return _parse_cached(str(path), modified_ns)
def _end_line(node: ast.AST) -> int:
    """Best-effort last line of *node*.

    Prefers ``end_lineno``, falls back to ``lineno``, and returns 0 when the
    node carries neither as an int.
    """
    for attribute in ("end_lineno", "lineno"):
        candidate = getattr(node, attribute, None)
        if isinstance(candidate, int):
            return candidate
    return 0
def _stable_decorator_unparse(node: ast.AST) -> str:
if isinstance(node, ast.Constant):
if isinstance(node.value, str):
return json.dumps(node.value)
return ast.unparse(node)
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return f"{_stable_decorator_unparse(node.value)}.{node.attr}"
if isinstance(node, ast.Call):
func = _stable_decorator_unparse(node.func)
arg_parts = [_stable_decorator_unparse(arg) for arg in node.args]
arg_parts.extend(
f"{keyword.arg}={_stable_decorator_unparse(keyword.value)}"
for keyword in node.keywords
if keyword.arg is not None
)
arg_parts.extend(
f"**{_stable_decorator_unparse(keyword.value)}" for keyword in node.keywords if keyword.arg is None
)
return f"{func}({', '.join(arg_parts)})"
if isinstance(node, ast.Subscript):
return f"{_stable_decorator_unparse(node.value)}[{_stable_decorator_unparse(node.slice)}]"
if isinstance(node, ast.Tuple):
inner = ", ".join(_stable_decorator_unparse(elt) for elt in node.elts)
return f"({inner},)" if len(node.elts) == 1 else f"({inner})"
if isinstance(node, ast.List):
inner = ", ".join(_stable_decorator_unparse(elt) for elt in node.elts)
return f"[{inner}]"
return ast.unparse(node)
def _mangle_decorator_suffix(decorator_list: list[ast.expr]) -> str:
    """Join the stable renderings of *decorator_list*, in source order, with ``@``."""
    rendered = map(_stable_decorator_unparse, decorator_list)
    return "@".join(rendered)
def _collect_decorator_lines(nodes: list[ast.stmt]) -> set[int]:
    """Return every decorator line of the definitions in *nodes*.

    Class bodies are searched recursively; function bodies are not.
    """
    found: set[int] = set()
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    for stmt in nodes:
        if not isinstance(stmt, definition_types):
            continue
        found |= _decorator_line_numbers(stmt)
        if isinstance(stmt, ast.ClassDef):
            found |= _collect_decorator_lines(list(stmt.body))
    return found
def _definition_start_line(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
) -> int:
    """First source line of a definition.

    This is the earliest decorator line when decorators are present, and the
    ``def``/``class`` line otherwise.
    """
    decorator_starts = [decorator.lineno for decorator in node.decorator_list]
    return min(decorator_starts, default=node.lineno)
def _decorator_line_numbers(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
) -> set[int]:
    """All lines covered by *node*'s decorator expressions, including multi-line ones."""
    return {
        line
        for decorator in node.decorator_list
        for line in range(decorator.lineno, _end_line(decorator) + 1)
    }
def _is_protocol_ellipsis_stub(node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""True when the function body is a single ``...`` (typing.Protocol stub)."""
if len(node.body) != 1:
return False
stmt = node.body[0]
return isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis
def _canonical_function_name(
    node: ast.FunctionDef | ast.AsyncFunctionDef,
    *,
    class_name: str | None,
) -> str:
    """Mangled symbol id for a function or method.

    Produces ``name`` or ``name@<decorators>``. When *class_name* is given,
    the result is prefixed with ``<class_name>::``.
    """
    local = node.name
    if node.decorator_list:
        local = f"{local}@{_mangle_decorator_suffix(node.decorator_list)}"
    return local if class_name is None else f"{class_name}::{local}"
def _canonical_class_decorator_symbol(node: ast.ClassDef) -> str | None:
    """Mangled decorator identity of a class, or ``None`` when it has no decorators.

    Used only for ``_find_definition_node`` lookup. It is not a ``test_map`` or
    gate key; class decorator lines gate via ``{name}::%``.
    """
    if node.decorator_list:
        return node.name + "@" + _mangle_decorator_suffix(node.decorator_list)
    return None
def _effective_definition_spans(
    nodes: list[ast.stmt],
    *,
    class_name: str | None,
) -> tuple[list[DefinitionSpan], list[tuple[int, str, int]]]:
    """Resolve the function definitions of one scope with last-wins semantics.

    *nodes* is a module body (*class_name* ``None``) or a class body
    (*class_name* is the owning class). Each non-stub function is keyed by its
    mangled qualified name (``foo``, ``_@deco("a")``, ``Bar::run@staticmethod``).
    A repeated key replaces the earlier span and records the shadow tuple
    ``(earlier_def_line, qualified_name, later_def_line)``. Different mangled
    names for the same bare name (``_@deco("a")`` vs ``_@deco("b")``) coexist.

    Returns the surviving spans, ordered by each key's first appearance, and
    the shadow tuples in discovery order.
    """
    winners: dict[str, tuple[int, DefinitionSpan]] = {}
    shadows: list[tuple[int, str, int]] = []
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    for stmt in nodes:
        if not isinstance(stmt, function_types) or _is_protocol_ellipsis_stub(stmt):
            continue
        key = _canonical_function_name(stmt, class_name=class_name)
        if key in winners:
            earlier_def_line = winners[key][0]
            shadows.append((earlier_def_line, key, stmt.lineno))
        # Re-assigning an existing key keeps its original dict position.
        winners[key] = (
            stmt.lineno,
            DefinitionSpan(key, _definition_start_line(stmt), _end_line(stmt)),
        )
    return [span for _, span in winners.values()], shadows
def collect_file_symbols(path: Path) -> FileSymbols:
    """Return canonical definition and class spans for *path*.

    Top-level functions come first, followed by the methods of each top-level
    class in source order. An unreadable or unparseable file yields an empty
    ``FileSymbols``.
    """
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError):
        return FileSymbols((), (), frozenset())
    module_body = list(tree.body)
    definitions, _ = _effective_definition_spans(module_body, class_name=None)
    classes = [node for node in module_body if isinstance(node, ast.ClassDef)]
    class_spans = [
        DefinitionSpan(cls.name, _definition_start_line(cls), _end_line(cls))
        for cls in classes
    ]
    for cls in classes:
        methods, _ = _effective_definition_spans(list(cls.body), class_name=cls.name)
        definitions.extend(methods)
    return FileSymbols(
        definitions=tuple(definitions),
        class_spans=tuple(class_spans),
        decorator_lines=frozenset(_collect_decorator_lines(module_body)),
    )
def collect_shadow_warnings(path: Path) -> tuple[ShadowWarning, ...]:
    """Return one warning per identical mangled-name collision in *path*.

    The module body is scanned first, then each top-level class body in
    source order. Parsing is memoized by ``_get_cached_tree``, so re-walking
    the bodies here (separately from ``collect_file_symbols``) is cheap. An
    unreadable or unparseable file yields ``()``.
    """
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError):
        return ()
    file_str = str(path)
    scopes: list[tuple[list[ast.stmt], str | None]] = [(list(tree.body), None)]
    scopes += [
        (list(node.body), node.name) for node in tree.body if isinstance(node, ast.ClassDef)
    ]
    warnings: list[ShadowWarning] = []
    for body, owner in scopes:
        _, shadows = _effective_definition_spans(body, class_name=owner)
        warnings.extend(
            ShadowWarning(file_str, earlier, qualified, later)
            for earlier, qualified, later in shadows
        )
    return tuple(warnings)
def canonical_symbol_for_line(symbols: FileSymbols, line: int) -> str:
    """Map a source line to its canonical symbol, with module/class fallbacks.

    * Decorator lines go to the enclosing class (``Class::%``), or to ``%``.
    * Otherwise the innermost (shortest) definition span containing *line*
      wins; on a tie the earliest listed span wins.
    * Lines outside any definition go to the enclosing class or to ``%``.
    """

    def enclosing_class_symbol() -> str:
        for cls in symbols.class_spans:
            if cls.start_line <= line <= cls.end_line:
                return f"{cls.qualified_name}::{MODULE_SYMBOL}"
        return MODULE_SYMBOL

    if line in symbols.decorator_lines:
        return enclosing_class_symbol()
    best = None
    best_size = 0
    for span in symbols.definitions:
        if not span.start_line <= line <= span.end_line:
            continue
        size = span.end_line - span.start_line + 1
        if best is None or size < best_size:
            best, best_size = span, size
    if best is not None:
        return best.qualified_name
    return enclosing_class_symbol()
def assert_canonical_symbol(symbol: str) -> None:
    """Validate that *symbol* is usable as a canonical symbol id.

    Raises:
        ValueError: *symbol* is empty (any falsy value).
    """
    if symbol:
        return
    raise ValueError(f"symbol must be non-empty, got {symbol!r}")
def _definition_executable_lines(
    file_symbols: FileSymbols,
    executable: set[int],
    canonical_symbol: str,
) -> set[int] | None:
    """Executable lines inside the first definition span named *canonical_symbol*.

    Returns ``None`` (rather than an empty set) when no definition has that name.
    """
    span = next(
        (s for s in file_symbols.definitions if s.qualified_name == canonical_symbol),
        None,
    )
    if span is None:
        return None
    return {line for line in range(span.start_line, span.end_line + 1) if line in executable}
def _class_body_executable_lines(
    file_symbols: FileSymbols,
    executable: set[int],
    canonical_symbol: str,
) -> set[int] | None:
    """Executable class-level lines for a ``Class::%`` symbol.

    Lines inside the class's method spans are excluded. Returns ``None`` when
    *canonical_symbol* is not of the ``Class::%`` form, and an empty set when
    no class of that name exists.
    """
    suffix = f"::{MODULE_SYMBOL}"
    if not canonical_symbol.endswith(suffix):
        return None
    owner = canonical_symbol.removesuffix(suffix)
    method_prefix = f"{owner}::"
    method_lines = {
        line
        for span in file_symbols.definitions
        if span.qualified_name.startswith(method_prefix)
        for line in range(span.start_line, span.end_line + 1)
    }
    for cls in file_symbols.class_spans:
        if cls.qualified_name == owner:
            return {
                line
                for line in range(cls.start_line, cls.end_line + 1)
                if line in executable and line not in method_lines
            }
    return set()
def _module_executable_lines(file_symbols: FileSymbols, executable: set[int]) -> set[int]:
    """Executable lines lying outside every definition span and class span.

    Returns a new set; *executable* is not modified.
    """
    inside_spans: set[int] = set()
    for span in (*file_symbols.definitions, *file_symbols.class_spans):
        inside_spans.update(range(span.start_line, span.end_line + 1))
    return executable - inside_spans
def _read_source_lines(path: Path) -> list[str] | None:
    """Return the lines of *path*, numbered exactly as the Python parser numbers them.

    ``str.splitlines`` is deliberately not used. Besides real newlines it also
    breaks on form feed, vertical tab, the ASCII file/group/record separators,
    NEL and the Unicode line/paragraph separators. The tokenizer does not treat
    any of these as line ends, so ``lines[lineno - 1]`` would drift away from
    AST line numbers. Text-mode reading already folds CRLF and lone CR into
    ``\\n`` (universal newlines), so splitting on ``\\n`` alone is sufficient.

    Returns ``None`` when the file cannot be read or is not valid UTF-8.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return None
    lines = text.split("\n")
    if lines[-1] == "":
        # A trailing newline (or an empty file) leaves an empty tail element
        # that is not a source line.
        lines.pop()
    return lines
def _filter_executable_source_lines(
    path: Path,
    source_lines: list[str],
    changed_lines: set[int] | frozenset[int],
) -> set[int]:
    """Keep the *changed_lines* that hold executable code.

    A line is kept when it lies within ``1..len(source_lines)``, its text is
    neither blank nor a ``#`` comment, and (when *path* parses) it is not
    reported by ``_collect_non_executable_lines``.

    If the AST is unavailable, the text heuristic is applied alone instead of
    raising. This covers a syntax error, and also a file that vanished or
    became unreadable after *source_lines* was read. The OSError case is
    handled the same way as at every other ``_get_cached_tree`` call site.
    """
    if not changed_lines:
        return set()
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError):
        skip: set[int] = set()
    else:
        skip = _collect_non_executable_lines(tree)
    line_count = len(source_lines)
    return {
        line_no
        for line_no in changed_lines
        if line_no not in skip
        and 1 <= line_no <= line_count
        and _line_text_is_executable(source_lines[line_no - 1])
    }
def _all_executable_lines(path: Path, source_lines: list[str]) -> set[int]:
    """Every executable line number of *path*, filtered over the whole file."""
    line_count = len(source_lines)
    if line_count == 0:
        return set()
    every_line = set(range(1, line_count + 1))
    return _filter_executable_source_lines(path, source_lines, every_line)
def executable_lines_for_canonical_symbol(path: Path, canonical_symbol: str) -> set[int]:
    """Return the executable source lines attributed to one canonical gated symbol.

    Resolution order: a definition span with that exact name, then a
    ``Class::%`` class body, then ``%`` for module-level code. Any other
    symbol, and an unreadable or empty file, yields an empty set.
    """
    source_lines = _read_source_lines(path)
    executable = _all_executable_lines(path, source_lines) if source_lines else set()
    if not executable:
        return set()
    file_symbols = collect_file_symbols(path)
    for resolver in (_definition_executable_lines, _class_body_executable_lines):
        resolved = resolver(file_symbols, executable, canonical_symbol)
        if resolved is not None:
            return resolved
    if canonical_symbol == MODULE_SYMBOL:
        return _module_executable_lines(file_symbols, executable)
    return set()
def _find_class_node(tree: ast.Module, class_name: str) -> ast.ClassDef | None:
    """First top-level ``class`` statement named *class_name*, or ``None``."""
    return next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.ClassDef) and node.name == class_name
        ),
        None,
    )
def _find_definition_node(
    path: Path,
    symbol: str,
) -> ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | None:
    """Return the AST node that canonical *symbol* refers to in *path*, or ``None``.

    Accepted forms:

    * ``Class::%``: the first top-level class with that name.
    * ``name`` / ``name@<decorators>``: a top-level (async) function.
    * ``Class@<decorators>``: a decorated top-level class (lookup-only id).
    * ``Class::method`` (optionally ``@<decorators>``-mangled): a method of a
      top-level class.

    Function and method lookups follow the same rules as
    ``_effective_definition_spans``: bodies that are a single ``...`` are
    ignored at both levels. When a mangled name is defined more than once, the
    *last* definition wins. Header and decorator lines taken from the returned
    node therefore describe the same definition whose span supplies the
    symbol's executable lines.

    Returns ``None`` when the file cannot be read or parsed, or nothing matches.
    """
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError):
        return None
    class_suffix = f"::{MODULE_SYMBOL}"
    if symbol.endswith(class_suffix):
        return _find_class_node(tree, symbol[: -len(class_suffix)])
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    found: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | None = None
    for node in tree.body:
        if isinstance(node, function_types):
            if (
                not _is_protocol_ellipsis_stub(node)
                and _canonical_function_name(node, class_name=None) == symbol
            ):
                found = node  # keep scanning: a later redefinition wins
            continue
        if not isinstance(node, ast.ClassDef):
            continue
        if _canonical_class_decorator_symbol(node) == symbol:
            found = node
        for item in node.body:
            if (
                isinstance(item, function_types)
                and not _is_protocol_ellipsis_stub(item)
                and _canonical_function_name(item, class_name=node.name) == symbol
            ):
                found = item
    return found
def _body_executable_lines(path: Path, symbol: str) -> set[int]:
    """Executable lines of *symbol*, minus a function's ``def`` line and decorator lines.

    Symbols that do not resolve to a function node (classes, ``%``, unknown
    names) get their full executable line set.
    """
    lines = executable_lines_for_canonical_symbol(path, symbol)
    node = _find_definition_node(path, symbol)
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        lines = lines - _decorator_line_numbers(node) - {node.lineno}
    return lines
def _proxy_body_lines(path: Path, symbol: str) -> set[int]:
    """Body lines that stand in for *symbol* when only its header or decorators changed.

    For a class this is its ``Class::%`` body. For anything else it is the
    function body without the header and decorator lines.
    """
    node = _find_definition_node(path, symbol)
    if not isinstance(node, ast.ClassDef):
        return _body_executable_lines(path, symbol)
    class_symbol = f"{node.name}::{MODULE_SYMBOL}"
    return executable_lines_for_canonical_symbol(path, class_symbol)
def import_symbol_for_definition(path: Path, symbol: str) -> str:
    """Return ``%`` or ``Class::%`` for decorator import coverage on a mangled definition.

    Class nodes and ``Class::...`` method symbols map to their class body.
    Everything else maps to the module symbol.
    """
    node = _find_definition_node(path, symbol)
    if isinstance(node, ast.ClassDef):
        owner = node.name
    elif "::" in symbol:
        owner = symbol.partition("::")[0]
    else:
        return MODULE_SYMBOL
    return f"{owner}::{MODULE_SYMBOL}"
def touched_definition_symbols(path: Path, changed_lines: set[int] | frozenset[int]) -> frozenset[str]:
    """Mangled definition symbols whose span intersects *changed_lines*.

    A top-level class additionally contributes ``Class::%`` when its
    ``class`` line or one of its decorator lines changed.
    """
    changed = set(changed_lines)
    if not changed:
        return frozenset()
    touched = {
        span.qualified_name
        for span in collect_file_symbols(path).definitions
        if any(span.start_line <= line <= span.end_line for line in changed)
    }
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError):
        return frozenset(touched)
    for node in tree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        header_lines = {node.lineno, *_decorator_line_numbers(node)}
        if not changed.isdisjoint(header_lines):
            touched.add(f"{node.name}::{MODULE_SYMBOL}")
    return frozenset(touched)
def coverage_checks_for_definition(
    path: Path,
    symbol: str,
    changed_lines: set[int] | frozenset[int],
) -> CoverageChecks:
    """Return import / strict / proxy line sets for gate coverage fallback.

    * ``import_lines``: changed decorator lines.
    * ``strict_lines``: changed executable body lines.
    * ``proxy_lines``: the whole executable body, but only when the header
      line or a decorator changed and no body line did.

    Empty input, or a symbol that cannot be resolved, yields three empty sets.
    """
    nothing = CoverageChecks(frozenset(), frozenset(), frozenset())
    changed = set(changed_lines)
    if not changed:
        return nothing
    node = _find_definition_node(path, symbol)
    if node is None:
        return nothing
    decorator_lines = _decorator_line_numbers(node)
    if isinstance(node, ast.ClassDef):
        body_lines = _proxy_body_lines(path, symbol) - decorator_lines
    else:
        body_lines = _body_executable_lines(path, symbol)
    decorator_hit = changed & decorator_lines
    body_hit = changed & body_lines
    header_hit = node.lineno in changed
    use_proxy = not body_hit and (header_hit or bool(decorator_hit))
    return CoverageChecks(
        import_lines=frozenset(decorator_hit),
        strict_lines=frozenset(body_hit),
        proxy_lines=frozenset(body_lines) if use_proxy else frozenset(),
    )
def gated_coverage_symbols(path: Path) -> frozenset[str]:
    """Canonical symbols that require test_map coverage for a newly added source file.

    Executable lines are claimed in three passes: definition spans
    (overlapping definitions may each require coverage), then class bodies
    (``Class::%``) for lines not yet claimed, then ``%`` for whatever remains.
    """
    source_lines = _read_source_lines(path)
    if not source_lines:
        return frozenset()
    executable = _all_executable_lines(path, source_lines)
    if not executable:
        return frozenset()
    file_symbols = collect_file_symbols(path)
    required: set[str] = set()
    unclaimed = set(executable)
    for span in file_symbols.definitions:
        owned = {line for line in range(span.start_line, span.end_line + 1) if line in executable}
        if owned:
            required.add(span.qualified_name)
            unclaimed -= owned
    for cls in file_symbols.class_spans:
        owned = {line for line in range(cls.start_line, cls.end_line + 1) if line in unclaimed}
        if owned:
            required.add(f"{cls.qualified_name}::{MODULE_SYMBOL}")
            unclaimed -= owned
    if unclaimed:
        required.add(MODULE_SYMBOL)
    return frozenset(required)
def iter_canonical_definition_spans(path: Path) -> list[DefinitionSpan]:
    """Return a fresh list of canonical spans for top-level defs and methods."""
    spans = collect_file_symbols(path).definitions
    return [*spans]
def top_level_definitions(path: Path) -> list[str]:
    """Return names of top-level functions, async functions, and classes, in source order.

    Unlike most helpers here, read and parse errors from ``_get_cached_tree``
    are not caught and reach the caller.
    """
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    module = _get_cached_tree(path)
    return [stmt.name for stmt in module.body if isinstance(stmt, definition_types)]
def iter_qualified_definition_spans(path: Path) -> list[DefinitionSpan]:
    """Return canonical definition spans for gate and test_map tooling (a new list)."""
    spans = iter_canonical_definition_spans(path)
    return spans[:]
def symbol_for_line(spans: list[DefinitionSpan], line: int) -> str | None:
    """Pick the smallest enclosing canonical span (innermost definition).

    On equal sizes the earliest span in *spans* wins. Returns ``None`` when
    no span contains *line*.
    """
    best = None
    for span in spans:
        if not span.start_line <= line <= span.end_line:
            continue
        if best is None or span.end_line - span.start_line < best.end_line - best.start_line:
            best = span
    return None if best is None else best.qualified_name
def canonical_symbol_for_path_line(path: Path, line: int) -> str:
    """Resolve *line* of the file at *path* to its canonical symbol id."""
    file_symbols = collect_file_symbols(path)
    return canonical_symbol_for_line(file_symbols, line)
def _collect_non_executable_lines(tree: ast.AST) -> set[int]:
    """Return line numbers of statements that carry no executable code.

    A single ``ast.walk`` over *tree* collects:

    * docstrings: a leading string-literal expression in a module, class, or
      (async) function body;
    * bare annotations such as ``x: int`` (``AnnAssign`` without a value);
    * assignments to a name in ``_NON_EXECUTABLE_ASSIGN_NAMES`` in plain
      (``__all__ = ...``), annotated (``__all__: Final = ...``) and augmented
      (``__all__ += ...``) form.

    Multi-line statements contribute every line from ``lineno`` through their
    last line.
    """
    lines: set[int] = set()

    def mark(stmt: ast.stmt) -> None:
        lines.update(range(stmt.lineno, _end_line(stmt) + 1))

    scope_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
    for node in ast.walk(tree):
        if isinstance(node, scope_types) and node.body:
            first = node.body[0]
            if (
                isinstance(first, ast.Expr)
                and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)
            ):
                mark(first)
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, (ast.AnnAssign, ast.AugAssign)):
            if isinstance(node, ast.AnnAssign) and node.value is None:
                # Annotation-only declaration: nothing is assigned.
                mark(node)
                continue
            targets = [node.target]
        else:
            continue
        if any(
            isinstance(target, ast.Name) and target.id in _NON_EXECUTABLE_ASSIGN_NAMES
            for target in targets
        ):
            mark(node)
    return lines
def filter_executable_lines(path: Path, changed_lines: set[int] | frozenset[int]) -> set[int]:
    """Drop comment-only, docstring, type-only, and __all__/__version__ diff lines.

    An empty *changed_lines*, or a file that cannot be read, yields an empty set.
    """
    if not changed_lines:
        return set()
    source_lines = _read_source_lines(path)
    if source_lines is not None:
        return _filter_executable_source_lines(path, source_lines, changed_lines)
    return set()
def _line_text_is_executable(line: str) -> bool:
    """Text-only heuristic: True when the line has content that is not a ``#`` comment."""
    content = line.strip()
    if not content:
        return False
    return content[0] != "#"
def symbols_for_lines(path: Path, line_numbers: set[int]) -> set[str]:
    """Return canonical symbols that enclose the given line numbers.

    Alias of ``canonical_symbols_for_lines``.
    """
    result = canonical_symbols_for_lines(path, line_numbers)
    return result
def canonical_symbols_for_lines(path: Path, line_numbers: set[int]) -> set[str]:
    """Return canonical symbols that enclose the given line numbers.

    Line numbers outside ``1..len(file)`` are ignored. An empty request, or
    an unreadable or empty file, yields an empty set.
    """
    if not line_numbers:
        return set()
    source_lines = _read_source_lines(path)
    if not source_lines:
        return set()
    last_line = len(source_lines)
    file_symbols = collect_file_symbols(path)
    return {
        canonical_symbol_for_line(file_symbols, line_no)
        for line_no in line_numbers
        if 1 <= line_no <= last_line
    }