"""Additional tests for retrieval.hierarchical_searcher."""

from __future__ import annotations

import copy
from datetime import datetime, timezone

from core.models import RetrievalConfig, SeedHit, SeedResult, TypedQuery
from retrieval.hierarchical_searcher import HierarchicalSearcher
from tests.unit.retrieval.conftest import make_ctx


class FakeVectorIndex:
    """In-memory stand-in for the vector index used by ``HierarchicalSearcher``.

    ``mapping`` maps a parent URI to the child hits returned for it. Every
    ``search_children`` call is recorded in ``calls`` so tests can assert on
    the parent, filters and ``top_k`` the searcher passed.
    """

    def __init__(self, mapping: dict[str, list[SeedHit]]) -> None:
        self.mapping = mapping
        # One (parent_uri, filters, top_k) tuple per search_children call, in call order.
        self.calls: list[tuple[str, dict, int]] = []

    def search_children(
        self,
        parent_uri: str,
        query_vector: list[float],
        filters: dict,
        top_k: int,
    ) -> list[SeedHit]:
        """Record the call and return the configured children of *parent_uri*.

        ``query_vector`` is ignored. A parent missing from ``mapping`` yields
        an empty list.
        """
        # Snapshot the filters so a later mutation by the caller cannot rewrite
        # the recorded history that tests assert against.
        self.calls.append((parent_uri, copy.copy(filters), top_k))
        # Return a fresh list so any in-place change made by the code under
        # test never leaks back into the shared fixture mapping.
        return list(self.mapping.get(parent_uri, []))


def test_expand_walks_hierarchy_and_returns_sorted_leaf_hits():
    """Expand from one starting point, descend into a nested node, and rank leaves.

    The fake index exposes a small tree under ``preferences``:

    * ``coffee``: a strong level-2 leaf;
    * ``ignore``: a level-2 leaf whose raw score (0.1) is below the threshold;
    * ``nested``: a level-1 node whose only child, ``tea``, is expected to rank first.

    Besides membership and top-rank checks, the test verifies that results are
    ordered by non-increasing score (as the test name promises), that ``limit``
    caps the result count, and that the first index lookup carries the expected
    tenant/type/owner filters.
    """
    root = "ctx://acct-1/users/u1/memories"
    preferences = f"{root}/preferences"
    nested = f"{preferences}/nested"
    coffee = f"{preferences}/coffee"
    ignore = f"{preferences}/ignore"
    tea = f"{nested}/tea"
    limit = 3

    mapping = {
        preferences: [
            SeedHit(
                uri=coffee,
                score=0.9,
                level=2,
                parent_uri=preferences,
                category="preferences",
                owner_space="user:u1",
                abstract="coffee",
                has_overview=True,
                has_content=True,
            ),
            SeedHit(
                uri=nested,
                score=0.8,
                level=1,
                parent_uri=preferences,
                category="preferences",
                owner_space="user:u1",
            ),
            SeedHit(
                uri=ignore,
                score=0.1,
                level=2,
                parent_uri=preferences,
                category="preferences",
                owner_space="user:u1",
            ),
        ],
        nested: [
            SeedHit(
                uri=tea,
                score=0.95,
                level=2,
                parent_uri=nested,
                category="preferences",
                owner_space="user:u1",
                abstract="tea",
                has_content=True,
            )
        ],
    }
    vector_index = FakeVectorIndex(mapping)
    searcher = HierarchicalSearcher(
        vector_index,
        RetrievalConfig(
            max_convergence_rounds=2,
            score_propagation_alpha=0.5,
            # Presumably disables hotness blending so ranking reflects vector
            # scores only; verify against RetrievalConfig.
            hotness_alpha=0.0,
            default_score_threshold=0.5,
        ),
    )
    seed_result = SeedResult(
        query_vector=[0.1, 0.2],
        initial_candidates=[
            SeedHit(
                uri=f"{root}/profile",
                score=0.7,
                level=2,
                category="profile",
                owner_space="user:u1",
                abstract="profile",
            )
        ],
        starting_points=[
            SeedHit(
                uri=preferences,
                score=0.8,
                level=1,
                parent_uri=root,
            )
        ],
    )

    results = searcher.expand(
        TypedQuery(text="coffee", context_type="MEMORY", categories=["preferences"], owner_space="user:u1"),
        seed_result,
        make_ctx(),
        limit=limit,
        score_threshold=0.5,
    )

    uris = [hit.uri for hit in results]
    assert ignore not in uris
    assert coffee in uris
    assert tea in uris
    assert results[0].uri == tea
    # The name promises sorted output: check the whole ordering, not just the head.
    scores = [hit.score for hit in results]
    assert scores == sorted(scores, reverse=True)
    assert len(results) <= limit
    assert vector_index.calls[0][1] == {
        "account_id": "acct-1",
        "context_type": "MEMORY",
        "owner_space": "user:u1",
    }


def test_recursive_search_stops_after_convergence():
    """Search two sibling parents that both return the same leaf.

    With a single convergence round and ``limit=1``, the search should produce
    exactly one result after exactly two child lookups on the index.
    """
    parent = "ctx://acct-1/users/u1/memories/preferences"
    alt_parent = f"{parent}_alt"
    # The very same SeedHit object is returned under both parents.
    shared_leaf = SeedHit(
        uri=f"{parent}/coffee",
        score=0.9,
        level=2,
        parent_uri=parent,
    )
    vector_index = FakeVectorIndex({parent: [shared_leaf], alt_parent: [shared_leaf]})
    config = RetrievalConfig(max_convergence_rounds=1, hotness_alpha=0.0)
    searcher = HierarchicalSearcher(vector_index, config)

    starts = [
        SeedHit(uri=parent, score=0.8, level=1),
        SeedHit(uri=alt_parent, score=0.7, level=1),
    ]
    query = TypedQuery(text="coffee", context_type="MEMORY", categories=[])

    results = searcher._recursive_search(
        query_vector=[0.1],
        starting_points=starts,
        initial_candidates=[],
        typed_query=query,
        ctx=make_ctx(),
        limit=1,
        threshold=None,
    )

    lookups = vector_index.calls
    assert len(results) == 1
    assert len(lookups) == 2


def test_convert_results_blends_hotness_and_tolerates_invalid_timestamps():
    """Convert raw index rows with hotness blending enabled.

    One row has a fresh, timezone-aware ``updated_at``; the other carries an
    unparseable string. Conversion must keep both rows and return them with the
    first score no lower than the second.
    """
    index = FakeVectorIndex({})
    config = RetrievalConfig(hotness_alpha=0.5, hotness_half_life_days=7.0)
    searcher = HierarchicalSearcher(index, config)

    # Lower base score, non-zero active_count, updated just now.
    fresh_row = {
        "uri": "one",
        "level": 2,
        "_final_score": 0.4,
        "active_count": 10,
        "updated_at": datetime.now(timezone.utc),
    }
    # Higher base score, but the timestamp cannot be parsed as a date.
    garbled_row = {
        "uri": "two",
        "level": 2,
        "_final_score": 0.9,
        "active_count": 0,
        "updated_at": "not-a-date",
    }

    results = searcher._convert_results([fresh_row, garbled_row], limit=2)

    returned_uris = set(hit.uri for hit in results)
    assert returned_uris == {"one", "two"}
    top, runner_up = results[0], results[1]
    assert top.score >= runner_up.score