"""Stage 2 — Seed retriever.

Performs a global vector search across levels L0, L1 and L2, optionally
augmented with BM25 keyword search via Vector-Anchored Fusion.
Splits results into starting_points (L0/L1 directory hits plus the scope
root URIs) and initial_candidates (L2 hits).
"""

from __future__ import annotations

import logging
from typing import Any

from core.enums import ContextType
from core.interfaces import VectorIndex, Embedder
from core.models import (
    RetrievalConfig,
    RequestContext,
    RetrieverMode,
    SeedResult,
    TypedQuery,
    SeedHit,
)

logger = logging.getLogger(__name__)


class SeedRetriever:
    """Global L0/L1/L2 search with optional BM25 hybrid fusion.

    Embeds the query and runs one account-scoped vector search. When a
    non-empty BM25 index is configured, keyword hits are fused in. The merged
    hits are then split into starting points (hits whose level is not 2, plus
    the scope root URIs) and initial L2 candidates.
    """

    def __init__(
        self,
        vector_index: VectorIndex,
        embedder: Embedder,
        config: RetrievalConfig | None = None,
        bm25_index: Any | None = None,
    ) -> None:
        """Create a seed retriever.

        Args:
            vector_index: Index queried through ``search_by_vector``.
            embedder: Embeds the query text through ``embed_texts``.
            config: Retrieval settings. A default ``RetrievalConfig()`` is used
                when omitted. ``global_search_topk`` caps both searches.
            bm25_index: Optional keyword index. It is used via ``doc_count``
                and ``search(query=..., top_k=..., filters=...)``, whose hits
                expose ``.metadata`` and ``.score``. Fusion is skipped when
                the index is absent or empty.
        """
        self.vector_index = vector_index
        self.embedder = embedder
        self.config = config or RetrievalConfig()
        self.bm25_index = bm25_index

    def search(
        self,
        typed_query: TypedQuery,
        ctx: RequestContext,
        *,
        mode: str = RetrieverMode.QUICK,
    ) -> SeedResult:
        """Run the seed stage for a single query.

        Args:
            typed_query: Query text plus optional ``context_type``,
                ``owner_space`` and ``categories`` filters.
            ctx: Request context. It supplies the account scope and the
                user/agent identities used to build the root URIs.
            mode: Retriever mode. It is accepted for interface compatibility
                but is not consulted by this stage.

        Returns:
            A ``SeedResult`` with the starting points, the initial L2
            candidates, the query embedding, and the scope root URIs.
        """
        query_vector = self.embedder.embed_texts([typed_query.text])[0]
        root_uris = self._get_root_uris(typed_query.context_type, ctx)

        global_results = self._global_vector_search(query_vector, typed_query, ctx)

        # Hybrid: merge BM25 results via Vector-Anchored Fusion.
        if self.bm25_index and self.bm25_index.doc_count > 0:
            bm25_results = self._bm25_search(typed_query, ctx)
            if bm25_results:
                global_results = self._fuse_results(global_results, bm25_results)

        starting_points, initial_candidates = self._merge_starting_points(
            root_uris, global_results,
        )

        return SeedResult(
            starting_points=starting_points,
            initial_candidates=initial_candidates,
            query_vector=query_vector,
            root_uris=root_uris,
        )

    @staticmethod
    def _get_root_uris(context_type: str | None, ctx: RequestContext) -> list[str]:
        """Return the URI roots the query is scoped to, within ``ctx.account_id``.

        Agent IDs are taken from ``agent:<id>`` entries in
        ``ctx.visible_owner_spaces``, sorted and de-duplicated. When there are
        none, ``ctx.agent_id`` is used if it is set.

        - No ``context_type``: the user's memories, the account resources,
          and each visible agent's memories and skills.
        - MEMORY: the user's memories plus each visible agent's memories.
        - RESOURCE: the account resources only.
        - SKILL: each visible agent's skills.
        - Any other type: no roots.
        """
        # SECURITY: URI format per CLAUDE.md §1:
        # - ctx://{account}/users/{user}/memories/
        # - ctx://{account}/agents/{agent}/memories/
        # - ctx://{account}/agents/{agent}/skills/
        account = ctx.account_id
        visible_spaces = list(getattr(ctx, "visible_owner_spaces", ()) or [])
        visible_agent_ids = sorted(
            {
                space.split(":", 1)[1]
                for space in visible_spaces
                # Skip non-strings and an empty id ("agent:").
                if isinstance(space, str) and space.startswith("agent:") and space.split(":", 1)[1]
            }
        )
        if not visible_agent_ids and ctx.agent_id:
            visible_agent_ids = [ctx.agent_id]
        if not context_type:
            roots = [
                f"ctx://{account}/users/{ctx.user_id}/memories/",
                f"ctx://{account}/resources/",
            ]
            roots.extend(f"ctx://{account}/agents/{agent_id}/memories/" for agent_id in visible_agent_ids)
            roots.extend(f"ctx://{account}/agents/{agent_id}/skills/" for agent_id in visible_agent_ids)
            return roots
        ct = context_type.upper()
        if ct == ContextType.MEMORY.value:
            roots = [f"ctx://{account}/users/{ctx.user_id}/memories/"]
            roots.extend(f"ctx://{account}/agents/{agent_id}/memories/" for agent_id in visible_agent_ids)
            return roots
        if ct == ContextType.RESOURCE.value:
            return [f"ctx://{account}/resources/"]
        if ct == ContextType.SKILL.value:
            return [f"ctx://{account}/agents/{agent_id}/skills/" for agent_id in visible_agent_ids]
        return []

    @staticmethod
    def _scope_filters(typed_query: TypedQuery, ctx: RequestContext) -> dict[str, Any]:
        """Build the metadata filters shared by the vector and BM25 searches.

        The filters always scope to ``ctx.account_id``. They add
        ``context_type``, ``owner_space`` and ``category`` only when the query
        sets them.
        """
        filters: dict[str, Any] = {"account_id": ctx.account_id}
        if typed_query.context_type:
            filters["context_type"] = typed_query.context_type
        if typed_query.owner_space:
            filters["owner_space"] = typed_query.owner_space
        if typed_query.categories:
            filters["category"] = typed_query.categories
        return filters

    def _global_vector_search(
        self,
        query_vector: list[float],
        typed_query: TypedQuery,
        ctx: RequestContext,
    ) -> list[SeedHit]:
        """Vector search over levels 0-2, returning up to ``global_search_topk`` hits."""
        filters: dict[str, Any] = {
            "level": [0, 1, 2],  # Include directory L0/L1 for hierarchical expansion
            **self._scope_filters(typed_query, ctx),
        }
        return self.vector_index.search_by_vector(
            query_vector=query_vector,
            filters=filters,
            top_k=self.config.global_search_topk,
        )

    def _bm25_search(
        self,
        typed_query: TypedQuery,
        ctx: RequestContext,
    ) -> list[SeedHit]:
        """BM25 keyword search with the same metadata scoping as the vector path.

        Each BM25 hit becomes a ``SeedHit`` built from its metadata, with
        ``score=0.0``. The raw BM25 score is carried in
        ``metadata["_bm25_raw_score"]`` for ``_fuse_results``.

        Hits whose metadata has no ``uri`` are dropped. They cannot be
        expanded, and they would all collapse onto the same ``""`` key during
        fusion.
        """
        bm25_hits = self.bm25_index.search(
            query=typed_query.text,
            top_k=self.config.global_search_topk,
            filters=self._scope_filters(typed_query, ctx),
        )
        hits: list[SeedHit] = []
        for h in bm25_hits:
            meta = h.metadata
            uri = meta.get("uri", "")
            if not uri:
                logger.debug("Skipping BM25 hit without a uri")
                continue
            hits.append(SeedHit(
                uri=uri,
                score=0.0,  # score set by _fuse_results
                level=int(meta.get("level", 2)),
                parent_uri=meta.get("parent_uri"),
                context_type=meta.get("context_type", ""),
                category=meta.get("category", ""),
                owner_space=meta.get("owner_space", ""),
                abstract=meta.get("abstract", ""),
                has_overview=bool(meta.get("has_overview")),
                has_content=bool(meta.get("has_content")),
                active_count=int(meta.get("active_count", 0)),
                updated_at=meta.get("updated_at"),
                metadata={"_bm25_raw_score": h.score},
            ))
        return hits

    @staticmethod
    def _fuse_results(
        vector_hits: list[SeedHit],
        bm25_hits: list[SeedHit],
        alpha: float = 0.7,
        saturation_k: float = 5.0,
    ) -> list[SeedHit]:
        """Vector-Anchored Fusion: combine vector and BM25 results.

        Formula:
            sat_bm25 = max(raw_bm25, 0) / (max(raw_bm25, 0) + saturation_k)
            final_score = alpha * vec_score + (1 - alpha) * sat_bm25

        A hit missing from one path takes that path's minimum score as a
        floor. Vector-only hits take the lowest BM25 saturation, and BM25-only
        hits take the lowest vector score.

        Args:
            vector_hits: Results from vector search (score = cosine similarity).
            bm25_hits: Results from BM25 search (raw score in
                ``metadata["_bm25_raw_score"]``).
            alpha: Weight for the vector score (default 0.7 = 70% vector).
            saturation_k: BM25 saturation constant (default 5.0).

        Returns:
            One ``SeedHit`` per distinct URI, sorted by fused score, highest
            first. A vector hit's fields win when a URI appears in both lists.
        """

        def saturate(raw: float) -> float:
            # Clamp at 0. Some BM25 variants emit negative scores, which would
            # otherwise give a negative boost, or divide by zero when
            # raw == -saturation_k.
            clamped = max(raw, 0.0)
            return clamped / (clamped + saturation_k)

        # Index hits by URI. Vector hits are inserted first, so they win.
        fused: dict[str, SeedHit] = {}
        vec_scores: dict[str, float] = {}
        bm25_scores: dict[str, float] = {}

        for hit in vector_hits:
            fused[hit.uri] = hit
            vec_scores[hit.uri] = hit.score

        for hit in bm25_hits:
            bm25_scores[hit.uri] = hit.metadata.get("_bm25_raw_score", 0.0)
            if hit.uri not in fused:
                fused[hit.uri] = hit

        if not fused:
            return []

        # Floor defaults: the minimum score from each path.
        min_vec = min(vec_scores.values()) if vec_scores else 0.0
        min_bm25_sat = min(saturate(s) for s in bm25_scores.values()) if bm25_scores else 0.0

        results: list[SeedHit] = []
        for uri, hit in fused.items():
            vec_s = vec_scores.get(uri, min_vec)
            raw_bm25 = bm25_scores.get(uri, 0.0)
            sat_bm25 = saturate(raw_bm25) if raw_bm25 > 0 else min_bm25_sat
            fused_score = alpha * vec_s + (1 - alpha) * sat_bm25

            results.append(SeedHit(
                uri=hit.uri,
                score=fused_score,
                level=hit.level,
                parent_uri=hit.parent_uri,
                context_type=hit.context_type,
                category=hit.category,
                owner_space=hit.owner_space,
                abstract=hit.abstract,
                has_overview=hit.has_overview,
                has_content=hit.has_content,
                active_count=hit.active_count,
                updated_at=hit.updated_at,
                metadata=hit.metadata,
            ))

        results.sort(key=lambda h: h.score, reverse=True)
        return results

    @staticmethod
    def _merge_starting_points(
        root_uris: list[str],
        global_results: list[SeedHit],
    ) -> tuple[list[SeedHit], list[SeedHit]]:
        """Split hits into (starting_points, initial_candidates).

        Level-2 hits go to the candidates in their original order, without
        de-duplication. All other hits become starting points, de-duplicated
        by URI. Each root URI not already present is then appended as a
        placeholder ``SeedHit`` with ``score=0.0`` and ``level=0``.
        """
        initial_candidates: list[SeedHit] = []
        points: list[SeedHit] = []
        seen: set[str] = set()

        for hit in global_results:
            if hit.level == 2:
                initial_candidates.append(hit)
            elif hit.uri not in seen:
                points.append(hit)
                seen.add(hit.uri)

        for uri in root_uris:
            if uri not in seen:
                points.append(SeedHit(uri=uri, score=0.0, level=0))
                seen.add(uri)

        return points, initial_candidates