"""Stage 3 — Hierarchical recursive searcher.

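Priority-queue-driven recursive expansion:
  - L0/L1 (directory) nodes are pushed back onto the queue for further expansion
  - L2 nodes are terminal hits and are never expanded
  - Score propagation: final = alpha * child + (1 - alpha) * parent
  - Children whose propagated score does not exceed the threshold are dropped
  - Convergence: stop once the top-k set is unchanged for N consecutive rounds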
"""

from __future__ import annotations

import heapq
import logging
from datetime import datetime, timezone
from typing import Any

from core.interfaces import VectorIndex
from core.models import (
    LeafHit,
    RetrievalConfig,
    RequestContext,
    RetrieverMode,
    SeedResult,
    TypedQuery,
    SeedHit,
)
from retrieval.hotness import hotness_score

logger = logging.getLogger(__name__)


class HierarchicalSearcher:
    """Best-first recursive searcher over a URI hierarchy.

    Starting points are expanded in descending score order via
    ``VectorIndex.search_children``. Every child whose propagated score clears
    the threshold is collected (de-duplicated by URI, highest score wins).
    Children whose level is not 2 are queued for further expansion, while
    level-2 nodes are terminal. The surviving candidates are finally re-ranked
    by blending in a hotness score.
    """

    def __init__(
        self,
        vector_index: VectorIndex,
        config: RetrievalConfig | None = None,
    ) -> None:
        """Create a searcher.

        Args:
            vector_index: Index queried for the children of each expanded URI.
            config: Retrieval settings; ``RetrievalConfig()`` is used when
                ``None`` (or any other falsy value) is passed.
        """
        self.vector_index = vector_index
        self.cfg = config or RetrievalConfig()

    def expand(
        self,
        typed_query: TypedQuery,
        seed_result: SeedResult,
        ctx: RequestContext,
        *,
        limit: int = 5,
        mode: str = RetrieverMode.THINKING,
        score_threshold: float | None = None,
    ) -> list[LeafHit]:
        """Run the recursive search and return at most ``limit`` ranked hits.

        Args:
            typed_query: Supplies the optional ``context_type`` and
                ``owner_space`` index filters.
            seed_result: Query vector, nodes to expand (``starting_points``)
                and seed hits (``initial_candidates``).
            ctx: Request context; its ``account_id`` is always part of the
                index filters.
            limit: Maximum number of hits to return; values <= 0 yield ``[]``.
            mode: Accepted for interface compatibility; not used here.
            score_threshold: Exclusive lower bound on the semantic score.
                Falls back to ``cfg.default_score_threshold`` when ``None``.

        Returns:
            Hits sorted by blended (semantic + hotness) score, descending.
        """
        if limit <= 0:
            # A non-positive bound would otherwise reach ``[:limit]`` slices,
            # where a negative value drops items from the *end* instead of
            # returning nothing. Bailing out also avoids needless index calls.
            return []
        candidates = self._recursive_search(
            query_vector=seed_result.query_vector,
            starting_points=seed_result.starting_points,
            initial_candidates=seed_result.initial_candidates,
            typed_query=typed_query,
            ctx=ctx,
            limit=limit,
            threshold=score_threshold,
        )
        return self._convert_results(candidates, limit)

    def _recursive_search(
        self,
        query_vector: list[float],
        starting_points: list[SeedHit],
        initial_candidates: list[SeedHit],
        typed_query: TypedQuery,
        ctx: RequestContext,
        limit: int,
        threshold: float | None,
    ) -> list[dict[str, Any]]:
        """Expand ``starting_points`` best-first and collect scored candidates.

        Each child's score is propagated from its parent as
        ``alpha * child + (1 - alpha) * parent``. A falsy (0.0) parent score
        leaves the child's raw score unchanged. Candidates scoring at or below
        the effective threshold are dropped and never expanded. The search
        stops when the queue empties, or when the top-``limit`` URI set has
        stayed identical (and full) for ``cfg.max_convergence_rounds``
        consecutive expansions that returned children.

        Returns:
            Up to ``limit`` candidate dicts (see ``_hit_to_dict``), sorted by
            ``_final_score`` descending.
        """
        effective_threshold = (
            threshold if threshold is not None else self.cfg.default_score_threshold
        )
        alpha = self.cfg.score_propagation_alpha
        # Over-fetch children per parent so thresholding still leaves enough.
        children_top_k = max(limit * 2, 20)
        base_filters = self._build_filters(typed_query, ctx)

        collected_by_uri: dict[str, dict[str, Any]] = {}
        visited: set[str] = set()
        prev_topk_uris: set[str] = set()
        convergence_rounds = 0

        # Seed leaf hits obey the same threshold and keep-the-best
        # de-duplication rule as expanded children.
        for hit in initial_candidates:
            if hit.level == 2 and hit.uri and hit.score > effective_threshold:
                self._keep_best(collected_by_uri, hit, hit.score)

        # heapq is a min-heap, so scores are negated for best-first order.
        dir_queue: list[tuple[float, str]] = [
            (-hit.score, hit.uri) for hit in starting_points
        ]
        heapq.heapify(dir_queue)

        while dir_queue:
            neg_score, current_uri = heapq.heappop(dir_queue)
            if current_uri in visited:
                continue  # duplicate queue entry; this node was already expanded
            visited.add(current_uri)
            current_score = -neg_score

            children = self.vector_index.search_children(
                parent_uri=current_uri,
                query_vector=query_vector,
                filters=base_filters,
                top_k=children_top_k,
            )
            if not children:
                continue

            for child in children:
                if current_score:
                    final_score = alpha * child.score + (1 - alpha) * current_score
                else:
                    # Treat a zero parent score as "no parent signal".
                    final_score = child.score
                if final_score <= effective_threshold:
                    continue

                self._keep_best(collected_by_uri, child, final_score)

                # Only non-leaf levels are expanded; level 2 is terminal.
                if child.uri not in visited and child.level != 2:
                    heapq.heappush(dir_queue, (-final_score, child.uri))

            current_topk_uris = {
                c["uri"] for c in self._top_k(collected_by_uri, limit)
            }
            if current_topk_uris == prev_topk_uris and len(current_topk_uris) >= limit:
                convergence_rounds += 1
                if convergence_rounds >= self.cfg.max_convergence_rounds:
                    break
            else:
                convergence_rounds = 0
                prev_topk_uris = current_topk_uris

        return self._top_k(collected_by_uri, limit)

    @staticmethod
    def _build_filters(typed_query: TypedQuery, ctx: RequestContext) -> dict[str, Any]:
        """Build the index filters: always ``account_id``, plus the optional
        ``context_type`` / ``owner_space`` when the query sets them (truthy)."""
        filters: dict[str, Any] = {"account_id": ctx.account_id}
        if typed_query.context_type:
            filters["context_type"] = typed_query.context_type
        if typed_query.owner_space:
            filters["owner_space"] = typed_query.owner_space
        return filters

    def _keep_best(
        self,
        collected_by_uri: dict[str, dict[str, Any]],
        hit: SeedHit,
        final_score: float,
    ) -> None:
        """Store ``hit`` under its URI unless a higher-or-equal score is held."""
        prev = collected_by_uri.get(hit.uri)
        if prev is None or final_score > prev.get("_final_score", 0):
            collected_by_uri[hit.uri] = self._hit_to_dict(hit, final_score)

    @staticmethod
    def _top_k(
        collected_by_uri: dict[str, dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Return the ``k`` best candidates by ``_final_score``.

        ``heapq.nlargest`` is documented as equivalent to
        ``sorted(..., reverse=True)[:k]`` (ties keep insertion order). It
        avoids sorting every collected candidate and returns ``[]`` for
        ``k <= 0``.
        """
        return heapq.nlargest(
            k, collected_by_uri.values(), key=lambda c: c.get("_final_score", 0)
        )

    def _convert_results(self, candidates: list[dict[str, Any]], limit: int) -> list[LeafHit]:
        """Blend semantic and hotness scores and build ``LeafHit`` objects.

        ``score = (1 - hotness_alpha) * semantic + hotness_alpha * hotness``.
        The result is re-sorted by that blended score and truncated to
        ``limit`` (non-positive limits yield ``[]``).
        """
        hotness_alpha = self.cfg.hotness_alpha
        half_life_days = self.cfg.hotness_half_life_days
        results: list[LeafHit] = []

        for c in candidates:
            semantic_score: float = c.get("_final_score", 0.0)
            h_score = hotness_score(
                active_count=c.get("active_count", 0),
                updated_at=self._parse_updated_at(c.get("updated_at")),
                half_life_days=half_life_days,
            )
            blended = (1 - hotness_alpha) * semantic_score + hotness_alpha * h_score

            results.append(
                LeafHit(
                    uri=c.get("uri", ""),
                    score=blended,
                    level=c.get("level", 2),
                    parent_uri=c.get("parent_uri"),
                    category=c.get("category", ""),
                    owner_space=c.get("owner_space", ""),
                    abstract=c.get("abstract", ""),
                    has_overview=c.get("has_overview", False),
                    has_content=c.get("has_content", False),
                    active_count=c.get("active_count", 0),
                    updated_at=c.get("updated_at"),
                )
            )

        results.sort(key=lambda hit: hit.score, reverse=True)
        return results[: max(limit, 0)]

    @staticmethod
    def _parse_updated_at(raw: Any) -> datetime | None:
        """Coerce a stored ``updated_at`` into a timezone-aware ``datetime``.

        Accepts ``datetime`` instances and ISO-8601 strings. A trailing ``Z``
        is rewritten to ``+00:00`` because ``datetime.fromisoformat`` only
        accepts ``Z`` from Python 3.11. Naive values are assumed to be UTC so
        naive and aware datetimes are never mixed.
        NOTE(review): confirm stored timestamps are UTC.
        Unparseable strings and other types yield ``None``.
        """
        if isinstance(raw, str):
            text = raw[:-1] + "+00:00" if raw.endswith("Z") else raw
            try:
                parsed = datetime.fromisoformat(text)
            except (ValueError, TypeError):
                return None
        elif isinstance(raw, datetime):
            parsed = raw
        else:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _hit_to_dict(hit: SeedHit, final_score: float) -> dict[str, Any]:
        """Flatten ``hit`` into a candidate dict.

        ``_final_score`` carries the (possibly propagated) semantic score used
        for ranking.
        """
        return {
            "uri": hit.uri,
            "level": hit.level,
            "parent_uri": hit.parent_uri,
            "category": hit.category,
            "owner_space": hit.owner_space,
            "abstract": hit.abstract,
            "has_overview": hit.has_overview,
            "has_content": hit.has_content,
            "active_count": hit.active_count,
            "updated_at": hit.updated_at,
            "_final_score": final_score,
        }