"""AST helpers for mapping tests to source symbols and filtering executable lines."""
from __future__ import annotations
import ast
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Final
# Assignment target names that ``_collect_non_executable_lines`` treats as
# non-executable (their statements are dropped from changed-line sets).
_NON_EXECUTABLE_ASSIGN_NAMES: Final = frozenset({"__all__", "__version__"})
@dataclass(frozen=True, slots=True)
class DefinitionSpan:
    """Inclusive line range covered by one function or method definition.

    Instances are immutable (``frozen=True``) and use ``__slots__``.
    """

    # Bare name for a top-level function, ``Class.method`` for a method.
    qualified_name: str
    # ``lineno`` of the definition node.
    start_line: int
    # Last line of the definition (inclusive).
    end_line: int
@lru_cache(maxsize=128)
def _parse_cached(path_str: str, mtime_ns: int) -> ast.Module:
    """Parse the file at ``path_str`` into a module AST, memoised per (path, mtime).

    ``mtime_ns`` is never read by the body; it only takes part in the cache
    key so that a file whose modification time changed is parsed again.

    The source is handed to ``ast.parse`` as raw bytes rather than as text
    decoded here, so the compiler applies its own source-encoding detection
    (UTF-8 BOM, PEP 263 coding cookie). Decoding as UTF-8 up front would leave
    a BOM in the text and reject files that declare another encoding.

    Raises:
        OSError: the file cannot be read.
        SyntaxError: the content is not valid Python source.
    """
    source = Path(path_str).read_bytes()
    return ast.parse(source, filename=path_str)
def _get_cached_tree(path: Path) -> ast.AST:
    """Return the AST for ``path``, keyed in the cache by its path and mtime."""
    cache_key = str(path)
    modified_ns = path.stat().st_mtime_ns
    return _parse_cached(cache_key, modified_ns)
def _end_line(node: ast.AST) -> int:
    """Return the last source line covered by ``node``.

    Uses ``node.end_lineno`` when it is set. Falls back to ``node.lineno``
    when the attribute is absent or ``None`` (the default for nodes that were
    not produced by the parser), so callers always receive an ``int`` they
    can safely use as a range bound.
    """
    end = getattr(node, "end_lineno", None)
    # ``lineno`` is only read when needed, so a node that carries just
    # ``end_lineno`` no longer fails on the eager default lookup.
    return node.lineno if end is None else end
def top_level_definitions(path: Path) -> list[str]:
    """List the names of module-level functions, async functions and classes.

    Names are returned in the order the definitions appear in the file.
    """
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
    module = _get_cached_tree(path)
    return [stmt.name for stmt in module.body if isinstance(stmt, definition_types)]
def iter_qualified_definition_spans(path: Path) -> list[DefinitionSpan]:
    """Return a ``DefinitionSpan`` for each top-level function and class method.

    Top-level (async) functions are reported under their bare name; (async)
    functions found directly in a top-level class body are reported as
    ``Class.method``. Classes themselves and deeper nesting produce no span.

    Returns an empty list when the file cannot be read, decoded or parsed.
    """
    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError, ValueError):
        # ValueError covers undecodable text (UnicodeDecodeError is a
        # subclass) and sources with null bytes on interpreters that report
        # those as ValueError; both mean "nothing usable to map".
        return []

    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    spans: list[DefinitionSpan] = []
    for node in tree.body:
        if isinstance(node, function_types):
            spans.append(DefinitionSpan(node.name, node.lineno, _end_line(node)))
        elif isinstance(node, ast.ClassDef):
            spans.extend(
                DefinitionSpan(f"{node.name}.{item.name}", item.lineno, _end_line(item))
                for item in node.body
                if isinstance(item, function_types)
            )
    return spans
def symbol_for_line(spans: list[DefinitionSpan], line: int) -> str | None:
    """Return the name of the shortest span containing ``line``, or ``None``.

    Span bounds are inclusive. When several containing spans have the same
    length, the one that appears first in ``spans`` wins.
    """
    innermost = min(
        (span for span in spans if span.start_line <= line <= span.end_line),
        key=lambda span: span.end_line - span.start_line,
        default=None,
    )
    return None if innermost is None else innermost.qualified_name
def _collect_non_executable_lines(tree: ast.AST) -> set[int]:
    """Collect the line numbers of statements that count as non-executable.

    One ``ast.walk`` pass marks every line of:

    * the docstring of a module, class, function or async function (a string
      constant expression as the first statement of its body);
    * a bare annotation without a value (``x: int``);
    * any assignment whose target is a name in
      ``_NON_EXECUTABLE_ASSIGN_NAMES`` -- plain (``__all__ = [...]``),
      annotated (``__all__: list[str] = [...]``) or augmented
      (``__all__ += [...]``).

    Multi-line statements contribute their whole line range.
    """
    scope_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
    lines: set[int] = set()

    def mark(stmt: ast.AST) -> None:
        # Record every line from the statement's first to its last.
        lines.update(range(stmt.lineno, _end_line(stmt) + 1))

    def is_ignored_name(target: ast.AST) -> bool:
        return isinstance(target, ast.Name) and target.id in _NON_EXECUTABLE_ASSIGN_NAMES

    for node in ast.walk(tree):
        if isinstance(node, scope_types):
            first = node.body[0] if node.body else None
            if (
                isinstance(first, ast.Expr)
                and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)
            ):
                mark(first)
        elif isinstance(node, ast.AnnAssign):
            # Value-less annotations are type-only; annotated assignments to
            # the ignored names are handled like their plain counterparts.
            if node.value is None or is_ignored_name(node.target):
                mark(node)
        elif isinstance(node, ast.AugAssign):
            if is_ignored_name(node.target):
                mark(node)
        elif isinstance(node, ast.Assign):
            if any(is_ignored_name(target) for target in node.targets):
                mark(node)
    return lines
def filter_executable_lines(path: Path, changed_lines: set[int] | frozenset[int]) -> set[int]:
    """Drop comment-only, docstring, type-only, and __all__/__version__ diff lines.

    Args:
        path: Python source file the line numbers refer to.
        changed_lines: 1-based line numbers to filter.

    Returns:
        The subset of ``changed_lines`` that lies inside the file, is neither
        blank nor a comment-only line and, when the file parses, is not part
        of a statement reported by ``_collect_non_executable_lines``. An
        empty set is returned when the file cannot be read.

    If the file cannot be parsed, only the textual blank/comment check is
    applied.
    """
    if not changed_lines:
        return set()
    try:
        # Only the shape of each line matters here (blank / starts with "#"),
        # so undecodable bytes are replaced instead of aborting the filter.
        text = path.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return set()

    # ``read_text`` already normalises "\r\n" and "\r" to "\n". Splitting on
    # "\n" alone keeps line numbers aligned with the parser and with diff
    # tools; ``str.splitlines`` would also break on form feeds, "\x1c"-"\x1e",
    # "\x85", "\u2028" and "\u2029", shifting every following line number.
    source_lines = text.split("\n")

    try:
        tree = _get_cached_tree(path)
    except (OSError, SyntaxError, ValueError):
        # Unparseable (or no longer readable) source: fall back to the purely
        # textual check. ValueError covers decode errors and null-byte sources.
        skip: set[int] = set()
    else:
        skip = _collect_non_executable_lines(tree)

    total = len(source_lines)
    return {
        line_no
        for line_no in changed_lines
        if line_no not in skip
        and 1 <= line_no <= total
        and _line_text_is_executable(source_lines[line_no - 1])
    }
def _line_text_is_executable(line: str) -> bool:
    """Tell whether ``line`` holds anything other than whitespace or a comment."""
    content = line.strip()
    if not content:
        return False
    return content[0] != "#"
def symbols_for_lines(path: Path, line_numbers: set[int]) -> set[str]:
    """Collect the qualified names of definitions in ``path`` that contain any given line.

    Lines that fall outside every known definition span contribute nothing.
    """
    if not line_numbers:
        return set()
    spans = iter_qualified_definition_spans(path)
    resolved = (symbol_for_line(spans, line_no) for line_no in line_numbers)
    return {name for name in resolved if name is not None}