"""Lightweight session topic detection."""
from __future__ import annotations
from dataclasses import dataclass
from math import sqrt
from typing import Any
from session.models import SessionMessage
@dataclass
class TopicDetection:
    """Outcome of one topic-detection pass over a session.

    Instances are intentionally mutable: ``changed`` is filled in after
    construction once the result has been compared with a previous one.
    """

    # Topic name: one of the built-in intent labels, a keyword taken from the
    # messages, or "general" when nothing stands out.
    label: str
    # Share of the inspected messages that support ``label``.
    confidence: float
    # Number of messages that support ``label``.
    message_count: int
    # True when ``label`` differs from the previously detected label.
    changed: bool = False
class TopicDetector:
    """Detect the dominant topic among the user messages of a session.

    Keyword matching is always available. When an embedder is supplied and
    there is more than one user message, messages are instead greedily
    clustered by cosine similarity of their embeddings and the largest
    cluster is labelled. Any problem with the embedder output falls back to
    keyword matching.
    """

    # Built-in intent labels, listed in tie-break priority order.
    _LABELS = ("debugging", "planning", "code_review")

    # For each intent label, the substrings that count as evidence for it.
    # Matching is by substring on lower-cased text, so overlapping entries
    # (e.g. "bug" inside "debug") each contribute to the score.
    _LABEL_KEYWORDS = (
        ("debugging", ("debug", "traceback", "failing", "failure", "bug", "fix")),
        ("planning", ("plan", "prd", "requirements", "implementation plan")),
        ("code_review", ("review", "risk", "regression", "finding")),
    )

    # Words ignored when looking for a free-form keyword topic.
    _STOPWORDS = frozenset({
        "the", "and", "for", "with", "this", "that", "please", "can",
        "you", "how", "what", "why", "when", "from", "into", "about",
    })

    def __init__(self, embedder: Any | None = None, similarity_threshold: float = 0.80):
        """Create a detector.

        Args:
            embedder: Optional object exposing ``embed_texts(texts)``. Its
                result is only used when it is a list holding one non-empty
                list per text, all of the same length.
            similarity_threshold: Minimum cosine similarity to the closest
                cluster centroid for a message to join that cluster.
        """
        self._embedder = embedder
        self._similarity_threshold = similarity_threshold

    def detect(
        self,
        messages: list[SessionMessage],
        previous: TopicDetection | None = None,
    ) -> TopicDetection:
        """Return the topic detected in ``messages``.

        Only messages whose role is ``"user"`` and whose content is
        non-blank are considered; a missing (``None``) content counts as
        blank.

        Args:
            messages: Session messages to inspect.
            previous: Earlier detection to compare against, if any.

        Returns:
            The detection result. ``changed`` is set from the label
            comparison only when ``previous`` is given.
        """
        user_messages = [
            m for m in messages
            if m.role == "user" and (m.content or "").strip()
        ]
        if not user_messages:
            result = TopicDetection(label="general", confidence=0.0, message_count=0)
        elif self._embedder is not None and len(user_messages) > 1:
            result = self._detect_with_embeddings(user_messages)
        else:
            result = self._detect_with_keywords(user_messages)
        if previous is not None:
            result.changed = previous.label != result.label
        return result

    def _detect_with_keywords(self, messages: list[SessionMessage]) -> TopicDetection:
        """Pick a topic by counting intent keywords across ``messages``.

        Falls back to the most widely shared free-form keyword, and finally
        to ``"general"`` with confidence 0.1, when no intent label matches.
        """
        label_counts: dict[str, int] = dict.fromkeys(self._LABELS, 0)
        label_scores: dict[str, int] = dict.fromkeys(self._LABELS, 0)
        for message in messages:
            label, score = self._classify_text(message.content)
            if label in label_counts and score > 0:
                label_counts[label] += 1
                label_scores[label] += score

        label = self._pick_label(label_counts, label_scores)
        if label == "general":
            generic_topic = self._detect_generic_keyword_topic(messages)
            if generic_topic is not None:
                return generic_topic
            return TopicDetection(label="general", confidence=0.1, message_count=len(messages))

        message_count = label_counts[label]
        return TopicDetection(
            label=label,
            confidence=self._share_confidence(message_count, len(messages)),
            message_count=message_count,
        )

    def _detect_with_embeddings(self, messages: list[SessionMessage]) -> TopicDetection:
        """Pick a topic by clustering message embeddings.

        Messages are visited in order; each joins the most similar existing
        cluster when the similarity reaches the threshold, otherwise it
        starts a new cluster. The largest cluster wins, with ties going to
        the cluster containing the latest message. Keyword detection is used
        instead when there are no messages, the embedder raises, or its
        output is not usable.
        """
        if not messages:
            # Nothing to embed or cluster; avoid indexing into an empty result.
            return self._detect_with_keywords(messages)

        texts = [message.content for message in messages]
        try:
            embeddings = self._embedder.embed_texts(texts)
        except Exception:
            # Embedding is best-effort: any embedder failure degrades to keywords.
            return self._detect_with_keywords(messages)
        if not self._embeddings_usable(embeddings, len(messages)):
            return self._detect_with_keywords(messages)

        clusters: list[list[int]] = []
        centroids: list[list[float]] = []
        for idx, vector in enumerate(embeddings):
            best_idx = -1
            best_score = -1.0
            for cluster_idx, centroid in enumerate(centroids):
                score = self._cosine_similarity(vector, centroid)
                if score > best_score:
                    best_idx = cluster_idx
                    best_score = score
            if best_idx >= 0 and best_score >= self._similarity_threshold:
                clusters[best_idx].append(idx)
                centroids[best_idx] = self._centroid(
                    [embeddings[i] for i in clusters[best_idx]]
                )
            else:
                clusters.append([idx])
                centroids.append(vector)

        # Largest cluster wins; on equal size prefer the most recent message.
        best_cluster = max(clusters, key=lambda c: (len(c), max(c)))
        cluster_messages = [messages[i] for i in best_cluster]
        return TopicDetection(
            label=self._label_cluster(cluster_messages),
            confidence=self._share_confidence(len(best_cluster), len(messages)),
            message_count=len(best_cluster),
        )

    def _label_cluster(self, messages: list[SessionMessage]) -> str:
        """Return a label for the messages of one cluster.

        Uses the keyword-detection label when it is not ``"general"``;
        otherwise the keyword present in the most messages (alphabetical on
        ties), or ``"general"`` when there is no candidate keyword.
        """
        keyword_topic = self._detect_with_keywords(messages)
        if keyword_topic.label != "general":
            return keyword_topic.label
        counts = self._generic_keyword_message_counts(messages)
        if not counts:
            return "general"
        return min(counts.items(), key=lambda item: (-item[1], item[0]))[0]

    def _detect_generic_keyword_topic(
        self,
        messages: list[SessionMessage],
    ) -> TopicDetection | None:
        """Return a topic named after the keyword shared by the most messages.

        Returns ``None`` when there is no candidate keyword or the best one
        appears in fewer than two messages.
        """
        counts = self._generic_keyword_message_counts(messages)
        if not counts:
            return None
        # Highest message count first, alphabetical order breaks ties.
        label, message_count = min(
            counts.items(),
            key=lambda item: (-item[1], item[0]),
        )
        if message_count < 2:
            return None
        return TopicDetection(
            label=label,
            confidence=self._share_confidence(message_count, len(messages)),
            message_count=message_count,
        )

    @staticmethod
    def _generic_keyword_message_counts(messages: list[SessionMessage]) -> dict[str, int]:
        """Count, for each candidate keyword, how many messages contain it.

        Text is lower-cased, underscores become spaces, and non-alphanumeric
        characters are dropped from each whitespace-separated token. Tokens
        shorter than four characters and stopwords are ignored. A message
        with ``None`` content contributes nothing.
        """
        stopwords = TopicDetector._STOPWORDS
        counts: dict[str, int] = {}
        for message in messages:
            text = (message.content or "").lower().replace("_", " ")
            # A set so that each word counts at most once per message.
            message_words: set[str] = set()
            for raw in text.split():
                word = "".join(ch for ch in raw if ch.isalnum())
                if len(word) >= 4 and word not in stopwords:
                    message_words.add(word)
            for word in message_words:
                counts[word] = counts.get(word, 0) + 1
        return counts

    def _classify_text(self, text: str) -> tuple[str, int]:
        """Return the best intent label for ``text`` and its keyword score.

        The score is the number of the label's keywords found as substrings
        of the lower-cased text. Returns ``("general", 0)`` when nothing
        matches; on equal scores the earlier label in the table wins.
        """
        lowered = (text or "").lower()
        best_label = "general"
        best_score = 0
        for label, keywords in self._LABEL_KEYWORDS:
            score = sum(1 for keyword in keywords if keyword in lowered)
            if score > best_score:
                best_label = label
                best_score = score
        return best_label, best_score

    @staticmethod
    def _pick_label(label_counts: dict[str, int], label_scores: dict[str, int]) -> str:
        """Return the intent label with the most matching messages.

        Equal message counts are resolved by the higher total keyword score,
        then by label priority order. Returns ``"general"`` when every count
        and score is zero.
        """
        best_label = "general"
        best_count = 0
        best_score = 0
        for label in TopicDetector._LABELS:
            count = label_counts.get(label, 0)
            score = label_scores.get(label, 0)
            if count > best_count or (count == best_count and score > best_score):
                best_label = label
                best_count = count
                best_score = score
        return best_label

    @staticmethod
    def _share_confidence(supporting: int, total: int) -> float:
        """Return ``supporting / total`` clamped to the range [0.1, 1.0]."""
        return max(0.1, min(1.0, supporting / max(1, total)))

    @staticmethod
    def _embeddings_usable(embeddings: Any, expected: int) -> bool:
        """Return True when ``embeddings`` can be clustered.

        That requires a non-empty list of exactly ``expected`` entries, each
        a non-empty list, all with the same length.
        """
        if not isinstance(embeddings, list) or not embeddings:
            return False
        if len(embeddings) != expected:
            return False
        if not all(isinstance(vector, list) and vector for vector in embeddings):
            return False
        dims = len(embeddings[0])
        return all(len(vector) == dims for vector in embeddings)

    @staticmethod
    def _centroid(vectors: list[list[float]]) -> list[float]:
        """Return the component-wise mean of ``vectors`` (``[]`` if empty).

        The dimensionality is taken from the first vector.
        """
        if not vectors:
            return []
        dims = len(vectors[0])
        return [
            sum(vector[idx] for vector in vectors) / len(vectors)
            for idx in range(dims)
        ]

    @staticmethod
    def _cosine_similarity(left: list[float], right: list[float]) -> float:
        """Return the cosine similarity of two vectors.

        Returns 0.0 when either vector is empty or has zero magnitude.
        """
        if not left or not right:
            return 0.0
        dot = sum(l * r for l, r in zip(left, right))
        left_norm = sqrt(sum(value * value for value in left))
        right_norm = sqrt(sum(value * value for value in right))
        if not left_norm or not right_norm:
            return 0.0
        return dot / (left_norm * right_norm)