"""ChromaDB-backed Vector Index with persistence.
Data survives server restarts. Uses ChromaDB's embedding-free mode - vectors
come from the Embedder, not from ChromaDB's built-in models.
"""
import hashlib
import logging
from typing import Any
from core.interfaces import VectorIndex
from core.models import IndexRecord, SeedHit
logger = logging.getLogger(__name__)
class ChromaVectorIndex(VectorIndex):
    """Persistent vector index backed by ChromaDB.

    Records are stored in a single Chroma collection configured for cosine
    distance. Embeddings are supplied by the caller through the
    ``"_embedding"`` entry of ``IndexRecord.metadata``; when a record carries
    no usable embedding, a deterministic pseudo-random unit vector derived
    from the record text is stored instead.
    """

    # Optional ``IndexRecord.metadata`` keys copied into the Chroma metadata.
    _OPTIONAL_META_KEYS = (
        "context_type", "category", "parent_uri",
        "has_overview", "has_content", "active_count", "updated_at",
        "when", "who", "where", "routing_key",
    )

    # Filter keys translated into a Chroma ``where`` clause; any other key
    # present in a filter dict is ignored.
    _FILTERABLE_KEYS = frozenset({
        "level", "account_id", "owner_space",
        "context_type", "category", "parent_uri",
    })

    def __init__(
        self,
        collection_name: str = "contextengine",
        persist_directory: str = ".chroma_data",
        dimension: int = 1536,
    ):
        """Open (or create) the persistent collection.

        Args:
            collection_name: Name of the Chroma collection to use.
            persist_directory: Directory handed to ``chromadb.PersistentClient``.
            dimension: Length of the fallback vectors generated by
                ``_stable_vector``.
        """
        # Imported lazily so the module can be imported without chromadb.
        import chromadb

        self._dimension = dimension
        self._client = chromadb.PersistentClient(path=persist_directory)
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        logger.info(
            "ChromaVectorIndex ready: collection=%s, dir=%s, existing=%d records",
            collection_name, persist_directory, self._collection.count(),
        )

    def upsert(self, records: list[IndexRecord]) -> None:
        """Insert or update ``records`` in the collection.

        Each record's embedding is taken from ``metadata["_embedding"]`` when
        it is a non-empty list/tuple (or an array-like exposing ``tolist()``);
        otherwise a deterministic fallback vector is generated from its text.
        An empty ``records`` list is a no-op.
        """
        if not records:
            return

        embeddings: list[list[float]] = []
        fallback_count = 0
        for record in records:
            vector = self._extract_embedding(record)
            if vector is None:
                vector = self._stable_vector(record.text)
                fallback_count += 1
            embeddings.append(vector)

        if fallback_count:
            # Make the otherwise invisible fallback traceable when debugging.
            logger.debug(
                "upsert: %d of %d records had no usable '_embedding'; "
                "stored hash-derived fallback vectors",
                fallback_count, len(records),
            )

        self._collection.upsert(
            ids=[record.id for record in records],
            embeddings=embeddings,
            documents=[record.text for record in records],
            metadatas=[self._record_to_meta(record) for record in records],
        )

    def delete(self, ids: list[str]) -> None:
        """Delete the records with the given ids; an empty list is a no-op."""
        if not ids:
            return
        self._collection.delete(ids=ids)

    def search_by_vector(
        self,
        query_vector: list[float],
        filters: dict[str, Any],
        top_k: int,
    ) -> list[SeedHit]:
        """Return up to ``top_k`` nearest records matching ``filters``.

        Args:
            query_vector: Embedding to search with.
            filters: Filter dict; see ``_build_where`` for the supported keys.
            top_k: Maximum number of hits to return.

        Returns:
            ``SeedHit`` objects in the order Chroma returns them. ``score`` is
            ``1 - distance`` clamped to ``[0, 1]``. An empty list is returned
            when ``top_k`` is not positive, the collection is empty, or the
            query yields nothing.
        """
        # A non-positive top_k can never produce hits; return before handing
        # an invalid n_results to Chroma.
        if top_k <= 0:
            return []
        total = self._collection.count()
        if total == 0:
            return []

        results = self._collection.query(
            query_embeddings=[query_vector],
            # Never request more results than there are stored records.
            n_results=min(top_k, total),
            where=self._build_where(filters),
            include=["documents", "metadatas", "distances"],
        )
        if not results:
            return []
        id_rows = results.get("ids")
        if not id_rows or not id_rows[0]:
            return []
        hit_count = len(id_rows[0])

        def column(key: str, default: Any) -> list:
            # Only one query vector is sent, so only row 0 is of interest.
            rows = results.get(key)
            return rows[0] if rows else [default] * hit_count

        return [
            self._to_seed_hit(meta, doc, dist)
            for meta, doc, dist in zip(
                column("metadatas", None),
                column("documents", ""),
                column("distances", 1.0),
            )
        ]

    def search_children(
        self,
        parent_uri: str,
        query_vector: list[float],
        filters: dict[str, Any],
        top_k: int,
    ) -> list[SeedHit]:
        """Search like ``search_by_vector`` but restricted to ``parent_uri``.

        The caller's ``filters`` dict is not mutated; a ``parent_uri`` entry
        in it is overridden by the ``parent_uri`` argument.
        """
        scoped = {**(filters or {}), "parent_uri": parent_uri}
        return self.search_by_vector(query_vector, scoped, top_k)

    @staticmethod
    def _extract_embedding(record: IndexRecord) -> list[float] | None:
        """Return the caller-supplied embedding of ``record``, or ``None``."""
        raw = (record.metadata or {}).get("_embedding")
        # Array-likes (e.g. numpy arrays) cannot be truth-tested directly;
        # convert them to plain Python lists first.
        if raw is not None and hasattr(raw, "tolist"):
            raw = raw.tolist()
        if isinstance(raw, (list, tuple)) and len(raw) > 0:
            return list(raw)
        return None

    @staticmethod
    def _to_seed_hit(meta: dict[str, Any] | None, doc: str | None, dist: float | None) -> SeedHit:
        """Build a ``SeedHit`` from one result row (metadata, document, distance)."""
        meta = meta or {}
        doc = doc or ""
        distance = 1.0 if dist is None else dist
        # Clamp on both sides so float error cannot push the score out of [0, 1].
        score = min(1.0, max(0.0, 1.0 - distance))
        return SeedHit(
            uri=meta.get("uri", ""),
            score=score,
            level=int(meta.get("level", 2)),
            parent_uri=meta.get("parent_uri"),
            context_type=meta.get("context_type", ""),
            category=meta.get("category", ""),
            owner_space=meta.get("owner_space", ""),
            # The stored document doubles as the abstract, truncated to 200 chars.
            abstract=doc[:200],
            has_overview=bool(meta.get("has_overview")),
            has_content=bool(meta.get("has_content")),
            active_count=int(meta.get("active_count", 0)),
            updated_at=meta.get("updated_at"),
            metadata={},
        )

    @staticmethod
    def _record_to_meta(record: IndexRecord) -> dict[str, Any]:
        """Flatten ``record`` into a metadata dict of scalar values.

        Always includes ``uri``, ``level``, ``account_id`` and ``owner_space``
        (the latter two default to ``""``), plus every non-``None`` optional
        key found in ``record.metadata``. ``None`` values are dropped and
        anything that is not a bool/int/float/str is stringified.
        """
        record_filters = record.filters or {}
        record_metadata = record.metadata or {}

        meta: dict[str, Any] = {
            "uri": record.uri,
            "level": record.level,
            "account_id": record_filters.get("account_id", ""),
            "owner_space": record_filters.get("owner_space", ""),
        }
        for key in ChromaVectorIndex._OPTIONAL_META_KEYS:
            value = record_metadata.get(key)
            if value is not None:
                meta[key] = value

        return {
            key: value if isinstance(value, (bool, int, float, str)) else str(value)
            for key, value in meta.items()
            if value is not None
        }

    @staticmethod
    def _build_where(filters: dict[str, Any]) -> dict[str, Any] | None:
        """Translate ``filters`` into a Chroma ``where`` clause.

        Only the keys in ``_FILTERABLE_KEYS`` are considered. A list, tuple
        or set value becomes an ``$in`` condition, any other value an
        equality condition. Multiple conditions are combined with ``$and``.

        Returns:
            The ``where`` dict, or ``None`` when no supported filter is given.
        """
        conditions: list[dict[str, Any]] = []
        for key, expected in (filters or {}).items():
            if key not in ChromaVectorIndex._FILTERABLE_KEYS:
                continue
            if isinstance(expected, (list, tuple, set, frozenset)):
                conditions.append({key: {"$in": list(expected)}})
            else:
                conditions.append({key: expected})

        if not conditions:
            return None
        if len(conditions) == 1:
            return conditions[0]
        return {"$and": conditions}

    def _stable_vector(self, text: str) -> list[float]:
        """Return a deterministic unit vector of length ``dimension`` for ``text``.

        The vector is pseudo-random, seeded from the MD5 digest of ``text``,
        so the same text always maps to the same vector.
        """
        import numpy as np

        # MD5 is used purely as a stable seed source, not for security; the
        # flag keeps the digest identical while saying so explicitly.
        digest = hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()
        # RandomState only accepts 32-bit seeds, hence the modulo.
        rng = np.random.RandomState(int(digest, 16) % (2**31))
        vec = rng.randn(self._dimension).astype(np.float32)
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec /= norm
        return vec.tolist()

    def count(self) -> int:
        """Return the number of records currently stored in the collection."""
        return self._collection.count()