"""Query helpers for test_node -> source_file -> symbols test_map."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterable

# Node-oriented test map: test node id -> source file path -> list of symbols
# that node references in that file (the helpers below iterate it in exactly
# that nesting order).
TestMap = dict[str, dict[str, list[str]]]


@dataclass(frozen=True, slots=True)
class TestMapIndex:
    """Reverse indexes for O(1) source/symbol watcher lookups."""

    _source_watchers: dict[str, frozenset[str]]
    _symbol_watchers: dict[tuple[str, str], frozenset[str]]

    def source_watchers(self, source_file: str) -> frozenset[str]:
        return self._source_watchers.get(source_file, frozenset())

    def symbol_watchers(self, source_file: str, symbol: str) -> frozenset[str]:
        return self._symbol_watchers.get((source_file, symbol), frozenset())


def build_test_map_index(test_map: TestMap) -> TestMapIndex:
    """Build reverse indexes from a node-oriented test_map (one O(N) pass).

    A source entry whose symbol list is empty is skipped entirely. This keeps
    the index in agreement with the scanning path of ``source_watchers``,
    which only counts a node for a source when ``sources.get(source_file)``
    is truthy; without the skip, the indexed and unindexed queries would
    return different node sets for the same map.

    Returns:
        A :class:`TestMapIndex` whose watcher sets are frozensets of node ids.
    """
    by_source: dict[str, set[str]] = {}
    by_symbol: dict[tuple[str, str], set[str]] = {}
    for node, sources in test_map.items():
        for src_path, symbols in sources.items():
            if not symbols:
                # Empty symbol list: not a watcher under the scan semantics.
                continue
            by_source.setdefault(src_path, set()).add(node)
            for symbol in symbols:
                by_symbol.setdefault((src_path, symbol), set()).add(node)
    return TestMapIndex(
        _source_watchers={path: frozenset(nodes) for path, nodes in by_source.items()},
        _symbol_watchers={key: frozenset(nodes) for key, nodes in by_symbol.items()},
    )


def nodes_for_test_file(test_map: TestMap, test_file: str) -> frozenset[str]:
    """Return every node id in *test_map* belonging to *test_file*.

    A node belongs to the file when its id starts with ``"<test_file>::"``.
    """
    marker = f"{test_file}::"
    matches: set[str] = set()
    for node_id in test_map:
        if node_id.startswith(marker):
            matches.add(node_id)
    return frozenset(matches)


def symbol_watchers(
    test_map: TestMap,
    source_file: str,
    symbol: str,
    *,
    index: TestMapIndex | None = None,
) -> frozenset[str]:
    """Return the test nodes that list *symbol* under *source_file*.

    With no *index*, every node in *test_map* is inspected (O(N) per call).
    For repeated queries, pass a :class:`TestMapIndex` built by
    :func:`build_test_map_index` and the answer comes from its lookup table.
    """
    if index is None:
        return frozenset(
            node
            for node, sources in test_map.items()
            if symbol in sources.get(source_file, ())
        )
    return index.symbol_watchers(source_file, symbol)


def source_watchers(
    test_map: TestMap,
    source_file: str,
    *,
    index: TestMapIndex | None = None,
) -> frozenset[str]:
    """Return the test nodes that map *source_file* to a non-empty symbol list.

    With no *index*, every node in *test_map* is inspected (O(N) per call).
    For repeated queries, pass a :class:`TestMapIndex` built by
    :func:`build_test_map_index` and the answer comes from its lookup table.
    """
    if index is None:
        # Truthiness check: a node whose entry for the file is empty is ignored.
        return frozenset(
            node
            for node, sources in test_map.items()
            if sources.get(source_file)
        )
    return index.source_watchers(source_file)


def is_source_file_mapped(
    test_map: TestMap,
    source_file: str,
    *,
    index: TestMapIndex | None = None,
) -> bool:
    """Return True if some test node watches *source_file*.

    Without *index*, a node counts only when its symbol list for
    *source_file* is non-empty. The scan stops at the first match instead of
    collecting every watcher just to test for emptiness. When *index* is
    given, the answer comes from ``index.source_watchers``.
    """
    if index is not None:
        return bool(index.source_watchers(source_file))
    # any() short-circuits; no intermediate set of watchers is built.
    return any(sources.get(source_file) for sources in test_map.values())


def is_symbol_mapped(
    test_map: TestMap,
    source_file: str,
    symbol: str,
    *,
    index: TestMapIndex | None = None,
) -> bool:
    """Return True if some test node lists *symbol* under *source_file*.

    Without *index*, the scan stops at the first matching node instead of
    collecting every watcher just to test for emptiness. When *index* is
    given, the answer comes from ``index.symbol_watchers``.
    """
    if index is not None:
        return bool(index.symbol_watchers(source_file, symbol))
    # any() short-circuits; no intermediate set of watchers is built.
    return any(symbol in sources.get(source_file, ()) for sources in test_map.values())


def symbols_mapped_for_source(test_map: TestMap, source_file: str) -> frozenset[str]:
    """Return the union of symbols any test node lists under *source_file*.

    Nodes that do not mention *source_file* contribute nothing.
    """
    return frozenset(
        symbol
        for per_node in test_map.values()
        for symbol in per_node.get(source_file, ())
    )


def prune_deleted_sources(test_map: TestMap, deleted: Iterable[str]) -> TestMap:
    """Return a new test map with every source path in *deleted* removed.

    *deleted* is consumed once. Any node left with no sources is omitted,
    including a node whose mapping was already empty. The input map is not
    modified, but the per-source symbol lists in the result are the same list
    objects as in the input; they are not copied.
    """
    gone = frozenset(deleted)
    result: TestMap = {}
    for node, sources in test_map.items():
        surviving = {
            path: symbols
            for path, symbols in sources.items()
            if path not in gone
        }
        if not surviving:
            continue
        result[node] = surviving
    return result