"""Application-level BM25 index for keyword retrieval.

Built from ChromaDB documents at startup, kept in sync via upsert/delete.
Uses rank_bm25 library with English tokenization (NLTK).
"""

from __future__ import annotations

import functools
import logging
import threading
from dataclasses import dataclass

from rank_bm25 import BM25Okapi

logger = logging.getLogger(__name__)

# Flipped to True by _ensure_nltk() after its one-time resource check;
# nltk itself is imported lazily inside functions to keep module import cheap.
_nltk_ready: bool = False


def _ensure_nltk():
    """Make sure the NLTK resources used for tokenization are available.

    Each resource is looked up in the local NLTK data path and downloaded if
    missing. The check runs once per process; later calls return immediately.
    A failed download is logged rather than raised (best effort).
    """
    global _nltk_ready
    if _nltk_ready:
        return
    import nltk

    # NLTK resource name -> its path inside the NLTK data directory.
    resources = {
        "punkt_tab": "tokenizers/punkt_tab",
        "stopwords": "corpora/stopwords",
    }
    for name, path in resources.items():
        try:
            nltk.data.find(path)
        except LookupError:
            # nltk.download() reports failure via its return value instead of
            # raising; surface it so a later tokenization error is explainable.
            if not nltk.download(name, quiet=True):
                logger.warning("NLTK resource %r could not be downloaded; tokenization may fail", name)
    _nltk_ready = True


@functools.lru_cache(maxsize=None)
def _english_stemmer_and_stopwords():
    """Build the Porter stemmer and English stopword set once and cache them.

    Previously both were rebuilt on every _tokenize() call, which re-read the
    stopword corpus for every document indexed and every query searched.
    """
    _ensure_nltk()
    import nltk
    from nltk.stem import PorterStemmer

    return PorterStemmer(), frozenset(nltk.corpus.stopwords.words("english"))


def _tokenize(text: str) -> list[str]:
    """Tokenize English text: lowercase, stem, remove stopwords.

    Tokens that are not purely alphanumeric (e.g. punctuation) are dropped.
    Stopwords are matched on the lowercased token before stemming.
    """
    _ensure_nltk()
    import nltk

    stemmer, stops = _english_stemmer_and_stopwords()
    tokens = nltk.word_tokenize(text.lower())
    return [stemmer.stem(t) for t in tokens if t.isalnum() and t not in stops]


@dataclass
class Bm25Hit:
    """One scored document produced by a BM25 search."""

    # Identifier of the matched document.
    id: str
    # Stored text of the matched document.
    text: str
    # BM25 score; search results are ordered by this, highest first.
    score: float
    # Metadata dict stored alongside the document.
    metadata: dict


class Bm25Index:
    """Thread-safe application-level BM25 index.

    Lifecycle:
    1. build_from_chroma() — initial load at startup
    2. upsert() / delete() — incremental updates from outbox worker
    3. search() — query-time retrieval

    Documents live in four parallel lists (ids, texts, metadata, token
    lists). Every read and write of those lists and of the BM25 model
    happens while holding ``_lock``. Every mutation rebuilds the BM25Okapi
    model from the cached token lists.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Parallel lists: position i in each one describes the same document.
        self._ids: list[str] = []
        self._texts: list[str] = []
        self._metas: list[dict] = []
        self._tokenized: list[list[str]] = []
        # None while the index holds no documents.
        self._bm25: BM25Okapi | None = None

    @staticmethod
    def _make_bm25(tokenized: list[list[str]]) -> BM25Okapi | None:
        """Build a BM25Okapi model, or return None for an empty corpus.

        rank_bm25 divides by the corpus size when computing the average
        document length, so ``BM25Okapi([])`` raises ZeroDivisionError.
        """
        return BM25Okapi(tokenized) if tokenized else None

    # -- Build / Rebuild -----------------------------------------------------

    def build_from_chroma(self, collection) -> None:
        """Load all documents from a ChromaDB collection and build the index.

        Replaces whatever the index held before, including when the
        collection is empty.

        Args:
            collection: ChromaDB collection; only ``count()`` and paged
                ``get(include=..., limit=..., offset=...)`` are used.
        """
        count = collection.count()
        if count == 0:
            with self._lock:
                self._ids, self._texts, self._metas, self._tokenized = [], [], [], []
                self._bm25 = None
            logger.info("Bm25Index: empty collection, nothing to index")
            return

        # Fetch all documents in batches
        batch_size = 1000
        all_ids: list[str] = []
        all_texts: list[str] = []
        all_metas: list[dict] = []
        for offset in range(0, count, batch_size):
            results = collection.get(
                include=["documents", "metadatas"],
                limit=batch_size,
                offset=offset,
            )
            ids = results["ids"]
            if not ids:
                break  # collection shrank after count() was taken
            # Pad a None field and replace None entries so the parallel lists
            # stay aligned with `ids` and every text can be tokenized.
            docs = results["documents"] or [None] * len(ids)
            metas = results["metadatas"] or [None] * len(ids)
            all_ids.extend(ids)
            all_texts.extend(d or "" for d in docs)
            all_metas.extend(m or {} for m in metas)

        with self._lock:
            # Compute everything first so a tokenization error cannot leave
            # the index half-replaced.
            tokenized = [_tokenize(t) for t in all_texts]
            bm25 = self._make_bm25(tokenized)
            self._ids, self._texts, self._metas = all_ids, all_texts, all_metas
            self._tokenized, self._bm25 = tokenized, bm25

        logger.info("Bm25Index: built from ChromaDB, %d documents indexed", len(all_ids))

    def _rebuild(self) -> None:
        """Re-tokenize every stored text and rebuild the model (caller must hold lock)."""
        tokenized = [_tokenize(t) for t in self._texts]
        self._tokenized = tokenized
        self._bm25 = self._make_bm25(tokenized)

    # -- Incremental Updates -------------------------------------------------

    def upsert(self, doc_id: str, text: str, metadata: dict) -> None:
        """Insert a document, or replace it if *doc_id* is already indexed.

        Args:
            doc_id: Document identifier.
            text: Document text to index.
            metadata: Metadata used by search filters; None is stored as {}.
        """
        meta = metadata or {}
        with self._lock:
            # Tokenize before touching any list so a failure leaves the
            # parallel lists aligned.
            tokens = _tokenize(text)
            try:
                pos = self._ids.index(doc_id)
            except ValueError:
                self._ids.append(doc_id)
                self._texts.append(text)
                self._metas.append(meta)
                self._tokenized.append(tokens)
            else:
                self._texts[pos] = text
                self._metas[pos] = meta
                self._tokenized[pos] = tokens
            self._bm25 = self._make_bm25(self._tokenized)

    def delete(self, doc_ids: list[str]) -> None:
        """Remove documents by ID; IDs that are not indexed are ignored."""
        doomed = set(doc_ids)
        if not doomed:
            return
        with self._lock:
            keep = [
                row
                for row in zip(self._ids, self._texts, self._metas, self._tokenized)
                if row[0] not in doomed
            ]
            if len(keep) == len(self._ids):
                return  # nothing to delete
            # Reuse cached token lists rather than re-tokenizing every
            # surviving document.
            self._ids = [row[0] for row in keep]
            self._texts = [row[1] for row in keep]
            self._metas = [row[2] for row in keep]
            self._tokenized = [row[3] for row in keep]
            self._bm25 = self._make_bm25(self._tokenized)

    # -- Search --------------------------------------------------------------

    def search(
        self,
        query: str,
        top_k: int = 30,
        filters: dict | None = None,
    ) -> list[Bm25Hit]:
        """Search BM25 index, optionally filtering by metadata.

        Args:
            query: Search query text
            top_k: Number of results to return (values <= 0 return no hits)
            filters: Optional metadata filters (account_id, owner_space, etc.)

        Returns:
            List of Bm25Hit sorted by BM25 score descending; documents with
            a non-positive score are omitted.
        """
        if top_k <= 0:
            return []
        with self._lock:
            if self._bm25 is None or not self._ids:
                return []

            tokenized_query = _tokenize(query)
            if not tokenized_query:
                return []

            scores = self._bm25.get_scores(tokenized_query)

            # Stay under the lock while reading the parallel lists: upsert()
            # mutates them in place, so reading after release could pair a
            # score with the wrong document. `score > 0` also rejects NaN.
            candidates = [
                (float(score), i)
                for i, score in enumerate(scores)
                if score > 0 and (not filters or _matches_filters(self._metas[i], filters))
            ]
            # Stable sort: equal scores keep index order.
            candidates.sort(key=lambda c: c[0], reverse=True)
            return [
                Bm25Hit(
                    id=self._ids[i],
                    text=self._texts[i],
                    score=score,
                    metadata=self._metas[i],
                )
                for score, i in candidates[:top_k]
            ]

    @property
    def doc_count(self) -> int:
        """Number of documents currently indexed."""
        return len(self._ids)


def _matches_filters(metadata: dict, filters: dict) -> bool:
    """Check if metadata matches all filter criteria."""
    for key, expected in filters.items():
        val = metadata.get(key)
        if val is None:
            return False
        if isinstance(expected, list):
            if val not in expected:
                return False
        elif val != expected:
            return False
    return True