import hashlib

from session.models import SessionMessage
from session.topic_detector import TopicDetector


class _Embedder:
    def embed_texts(self, texts):
        vectors = []
        for text in texts:
            if "debug" in text.lower() or "traceback" in text.lower():
                vectors.append([1.0, 0.0])
            else:
                vectors.append([0.0, 1.0])
        return vectors


class _ClusterEmbedder:
    def embed_texts(self, texts):
        mapping = {
            "latency p95 dashboard": [1.0, 0.0],
            "tail latency percentile": [0.95, 0.05],
            "billing invoice export": [0.0, 1.0],
        }
        return [mapping[text] for text in texts]


class _MalformedEmbedder:
    def embed_texts(self, texts):
        return [[1.0, 0.0], [0.95], [0.0, 1.0]]


def _msg(content: str) -> SessionMessage:
    """Build a user-role ``SessionMessage`` carrying *content*.

    The id is the first eight hex digits of a SHA-1 digest of the full
    content, not the content's first eight characters. With a prefix-based
    id, distinct messages sharing a prefix collided; for example,
    "billing invoice export" and "billing payment reconciliation" both
    became ``"billing "``. Identical content still yields identical ids.
    """
    message_id = hashlib.sha1(content.encode("utf-8")).hexdigest()[:8]
    return SessionMessage(id=message_id, role="user", content=content)


def test_detect_topic_uses_keyword_fallback_without_embedder():
    """Without an embedder, a debugging-worded message is labelled 'debugging'."""
    detector = TopicDetector()
    messages = [_msg("Please debug this failing pytest traceback")]

    result = detector.detect(messages)

    assert result.label == "debugging"
    assert result.confidence > 0


def test_detect_topic_uses_embedding_cluster_when_embedder_available():
    """With an embedder, the reported topic covers a two-message cluster."""
    detector = TopicDetector(embedder=_Embedder())
    conversation = [
        _msg(text)
        for text in (
            "Debug the traceback",
            "Fix this pytest failure",
            "Review the architecture plan",
        )
    ]

    result = detector.detect(conversation)

    assert result.label == "debugging"
    assert result.message_count == 2


def test_detect_topic_switch_reports_changed_label():
    """Passing the prior topic as ``previous`` flags a label change."""
    detector = TopicDetector()

    earlier = detector.detect([_msg("Please debug this failing test")])
    later = detector.detect(
        [_msg("Write an implementation plan")],
        previous=earlier,
    )

    assert earlier.label == "debugging"
    assert later.label == "planning"
    assert later.changed is True


def test_detect_topic_clusters_embeddings_before_labeling():
    """The two latency phrases cluster together and outweigh the billing one."""
    detector = TopicDetector(
        embedder=_ClusterEmbedder(),
        similarity_threshold=0.8,
    )
    phrases = (
        "latency p95 dashboard",
        "tail latency percentile",
        "billing invoice export",
    )

    result = detector.detect([_msg(phrase) for phrase in phrases])

    assert result.label == "latency"
    assert result.message_count == 2
    assert result.confidence > 0.5


def test_detect_topic_falls_back_for_mixed_dimension_embeddings():
    """Vectors of differing length still produce a keyword-derived topic."""
    detector = TopicDetector(
        embedder=_MalformedEmbedder(),
        similarity_threshold=0.8,
    )
    phrases = (
        "Please debug this traceback",
        "Fix this failing test",
        "Write an implementation plan",
    )

    result = detector.detect([_msg(phrase) for phrase in phrases])

    assert result.label == "debugging"
    assert result.message_count == 2


def test_detect_topic_uses_generic_keyword_label_for_non_code_topics():
    """A keyword repeated across messages ('billing') becomes the label."""
    detector = TopicDetector()
    billing_thread = [
        _msg(text)
        for text in ("billing invoice export", "billing payment reconciliation")
    ]

    result = detector.detect(billing_thread)

    assert result.label == "billing"
    assert result.message_count == 2


def test_detect_topic_keeps_general_for_single_generic_message():
    """A lone message with no recognised theme stays labelled 'general'."""
    detector = TopicDetector()
    lone_message = _msg("status update invoice export")

    result = detector.detect([lone_message])

    assert result.label == "general"
    assert result.message_count == 1