"""Stage 2 — Seed retriever.
Performs a global vector search across levels 0, 1 and 2, optionally
augmented with BM25 keyword search via Vector-Anchored Fusion.
Splits results into starting_points (non-level-2 hits plus the root URIs)
and initial_candidates (level-2 hits).
"""
from __future__ import annotations
import logging
from typing import Any
from core.enums import ContextType
from core.interfaces import VectorIndex, Embedder
from core.models import (
RetrievalConfig,
RequestContext,
RetrieverMode,
SeedResult,
TypedQuery,
SeedHit,
)
logger = logging.getLogger(__name__)
class SeedRetriever:
    """Seed stage: global vector search with optional BM25 hybrid fusion.

    The vector search is restricted to levels 0, 1 and 2 of the caller's
    account. When a non-empty BM25 index is configured, its hits are fused
    with the vector hits (see ``_fuse_results``). The combined hits are then
    split into level-2 ``initial_candidates`` and non-level-2
    ``starting_points``; the latter are topped up with the root URIs
    visible to the request.
    """

    def __init__(
        self,
        vector_index: VectorIndex,
        embedder: Embedder,
        config: RetrievalConfig | None = None,
        bm25_index: Any | None = None,
    ) -> None:
        """Store collaborators.

        Args:
            vector_index: Index queried through ``search_by_vector``.
            embedder: Used through ``embed_texts`` to embed the query text.
            config: Retrieval settings; a default ``RetrievalConfig()`` is
                created when omitted. Only ``global_search_topk`` is read
                by this class.
            bm25_index: Optional keyword index. This class relies on it
                exposing ``doc_count`` and ``search(query=, top_k=,
                filters=)`` returning objects with ``metadata`` and
                ``score`` attributes.
        """
        self.vector_index = vector_index
        self.embedder = embedder
        self.config = config or RetrievalConfig()
        self.bm25_index = bm25_index

    def search(
        self,
        typed_query: TypedQuery,
        ctx: RequestContext,
        *,
        mode: str = RetrieverMode.QUICK,
    ) -> SeedResult:
        """Run the seed search for one typed query.

        Args:
            typed_query: Query text plus optional context-type, owner-space
                and category constraints.
            ctx: Request context supplying account/user/agent identity.
            mode: Accepted for interface compatibility; not consulted by
                this method.

        Returns:
            A ``SeedResult`` carrying the starting points, the initial
            level-2 candidates, the query embedding and the root URIs.
        """
        query_vector = self.embedder.embed_texts([typed_query.text])[0]
        root_uris = self._get_root_uris(typed_query.context_type, ctx)

        global_results = self._global_vector_search(query_vector, typed_query, ctx)

        # BM25 is strictly optional: only consult it when an index exists and
        # actually holds documents, and only fuse when it returned something.
        if self.bm25_index and self.bm25_index.doc_count > 0:
            bm25_results = self._bm25_search(typed_query, ctx)
            if bm25_results:
                global_results = self._fuse_results(global_results, bm25_results)

        starting_points, initial_candidates = self._merge_starting_points(
            root_uris, global_results,
        )
        return SeedResult(
            starting_points=starting_points,
            initial_candidates=initial_candidates,
            query_vector=query_vector,
            root_uris=root_uris,
        )

    @staticmethod
    def _get_root_uris(context_type: str | None, ctx: RequestContext) -> list[str]:
        """Build the ``ctx://`` root URIs relevant to ``context_type``.

        Agent IDs are taken from the ``agent:<id>`` entries of
        ``ctx.visible_owner_spaces`` (sorted, de-duplicated); when none are
        present, ``ctx.agent_id`` is used if set.

        Returns:
            All user/resource/agent roots when ``context_type`` is falsy,
            the roots of the matching context type otherwise, and an empty
            list for an unrecognised type.
        """
        account = ctx.account_id
        agent_prefix = "agent:"

        visible_spaces = list(getattr(ctx, "visible_owner_spaces", ()) or [])
        agent_id_set: set[str] = set()
        for space in visible_spaces:
            if not isinstance(space, str) or not space.startswith(agent_prefix):
                continue
            agent_id = space[len(agent_prefix):]
            if agent_id:
                agent_id_set.add(agent_id)
        visible_agent_ids = sorted(agent_id_set)
        if not visible_agent_ids and ctx.agent_id:
            visible_agent_ids = [ctx.agent_id]

        user_memories = f"ctx://{account}/users/{ctx.user_id}/memories/"
        resources = f"ctx://{account}/resources/"
        agent_memories = [
            f"ctx://{account}/agents/{agent_id}/memories/"
            for agent_id in visible_agent_ids
        ]
        agent_skills = [
            f"ctx://{account}/agents/{agent_id}/skills/"
            for agent_id in visible_agent_ids
        ]

        if not context_type:
            return [user_memories, resources, *agent_memories, *agent_skills]

        ct = context_type.upper()
        if ct == ContextType.MEMORY.value:
            return [user_memories, *agent_memories]
        if ct == ContextType.RESOURCE.value:
            return [resources]
        if ct == ContextType.SKILL.value:
            return agent_skills
        return []

    @staticmethod
    def _metadata_filters(
        typed_query: TypedQuery,
        ctx: RequestContext,
    ) -> dict[str, Any]:
        """Return the metadata filters shared by the vector and BM25 searches.

        Always scopes to ``ctx.account_id``; adds ``context_type``,
        ``owner_space`` and ``category`` only when the query sets them.
        """
        filters: dict[str, Any] = {"account_id": ctx.account_id}
        if typed_query.context_type:
            filters["context_type"] = typed_query.context_type
        if typed_query.owner_space:
            filters["owner_space"] = typed_query.owner_space
        if typed_query.categories:
            filters["category"] = typed_query.categories
        return filters

    def _global_vector_search(
        self,
        query_vector: list[float],
        typed_query: TypedQuery,
        ctx: RequestContext,
    ) -> list[SeedHit]:
        """Vector search over levels 0-2 with the query's metadata filters."""
        # "level" goes first so the resulting filter dict keeps the same key
        # order as before the filter construction was factored out.
        filters: dict[str, Any] = {
            "level": [0, 1, 2],
            **self._metadata_filters(typed_query, ctx),
        }
        return self.vector_index.search_by_vector(
            query_vector=query_vector,
            filters=filters,
            top_k=self.config.global_search_topk,
        )

    def _bm25_search(
        self,
        typed_query: TypedQuery,
        ctx: RequestContext,
    ) -> list[SeedHit]:
        """BM25 keyword search with metadata filtering.

        Each BM25 hit is converted to a ``SeedHit`` with ``score=0.0``; the
        raw BM25 score is carried in ``metadata["_bm25_raw_score"]`` for
        ``_fuse_results``. Hits whose metadata has no ``uri`` are skipped
        because they cannot be identified or fused.
        """
        bm25_hits = self.bm25_index.search(
            query=typed_query.text,
            top_k=self.config.global_search_topk,
            filters=self._metadata_filters(typed_query, ctx),
        )

        hits: list[SeedHit] = []
        for h in bm25_hits:
            meta = h.metadata or {}
            uri = meta.get("uri")
            if not uri:
                # An empty URI would collapse every such hit onto one bogus
                # entry during fusion, so drop it here instead.
                logger.debug("Skipping BM25 hit without a uri (score=%s)", h.score)
                continue

            # A stored null must not reach int(); level 0 is valid, so test
            # for None explicitly rather than for falsiness.
            level = meta.get("level")
            active_count = meta.get("active_count")
            hits.append(SeedHit(
                uri=uri,
                score=0.0,
                level=2 if level is None else int(level),
                parent_uri=meta.get("parent_uri"),
                context_type=meta.get("context_type", ""),
                category=meta.get("category", ""),
                owner_space=meta.get("owner_space", ""),
                abstract=meta.get("abstract", ""),
                has_overview=bool(meta.get("has_overview")),
                has_content=bool(meta.get("has_content")),
                active_count=0 if active_count is None else int(active_count),
                updated_at=meta.get("updated_at"),
                metadata={"_bm25_raw_score": h.score},
            ))
        return hits

    @staticmethod
    def _fuse_results(
        vector_hits: list[SeedHit],
        bm25_hits: list[SeedHit],
        alpha: float = 0.7,
        saturation_k: float = 5.0,
    ) -> list[SeedHit]:
        """Vector-Anchored Fusion: combine vector and BM25 results.

        Formula:
            sat_bm25 = raw_bm25 / (raw_bm25 + saturation_k)
            final_score = alpha * vec_score + (1 - alpha) * sat_bm25

        A hit missing from one side receives that side's minimum observed
        score (minimum vector score, or minimum saturated BM25 score).
        Raw BM25 scores are clamped to ``>= 0`` before saturation so that a
        negative score can neither divide by zero nor produce a saturated
        value outside the expected range.

        Args:
            vector_hits: Results from vector search (score = cosine similarity)
            bm25_hits: Results from BM25 search; the raw BM25 score is read
                from ``metadata["_bm25_raw_score"]``
            alpha: Weight for vector score (default 0.7 = 70% vector)
            saturation_k: BM25 saturation constant (default 5.0)

        Returns:
            One ``SeedHit`` per distinct URI, sorted by fused score,
            highest first. Fields other than ``score`` come from the vector
            hit when the URI appears on both sides.
        """

        def saturate(raw: float) -> float:
            # raw is already clamped to >= 0; the guard only matters for a
            # non-positive saturation_k.
            denominator = raw + saturation_k
            return raw / denominator if denominator > 0 else 0.0

        fused: dict[str, SeedHit] = {}
        vec_scores: dict[str, float] = {}
        bm25_scores: dict[str, float] = {}

        for hit in vector_hits:
            fused[hit.uri] = hit
            vec_scores[hit.uri] = hit.score

        for hit in bm25_hits:
            raw = (hit.metadata or {}).get("_bm25_raw_score") or 0.0
            bm25_scores[hit.uri] = max(float(raw), 0.0)
            # Vector hits take precedence as the carrier of the hit's fields.
            fused.setdefault(hit.uri, hit)

        if not fused:
            return []

        min_vec = min(vec_scores.values(), default=0.0)
        min_bm25_sat = min(
            (saturate(s) for s in bm25_scores.values()),
            default=0.0,
        )

        results: list[SeedHit] = []
        for uri, hit in fused.items():
            vec_s = vec_scores.get(uri, min_vec)
            raw_bm25 = bm25_scores.get(uri, 0.0)
            sat_bm25 = saturate(raw_bm25) if raw_bm25 > 0 else min_bm25_sat
            fused_score = alpha * vec_s + (1 - alpha) * sat_bm25
            results.append(SeedHit(
                uri=hit.uri,
                score=fused_score,
                level=hit.level,
                parent_uri=hit.parent_uri,
                context_type=hit.context_type,
                category=hit.category,
                owner_space=hit.owner_space,
                abstract=hit.abstract,
                has_overview=hit.has_overview,
                has_content=hit.has_content,
                active_count=hit.active_count,
                updated_at=hit.updated_at,
                metadata=hit.metadata,
            ))

        results.sort(key=lambda h: h.score, reverse=True)
        return results

    @staticmethod
    def _merge_starting_points(
        root_uris: list[str],
        global_results: list[SeedHit],
    ) -> tuple[list[SeedHit], list[SeedHit]]:
        """Split hits into starting points and level-2 candidates.

        Level-2 hits become initial candidates in their original order.
        Every other hit becomes a starting point, de-duplicated by URI.
        Root URIs not already present are appended as zero-score level-0
        starting points.

        Returns:
            ``(starting_points, initial_candidates)``.
        """
        initial_candidates: list[SeedHit] = []
        points: list[SeedHit] = []
        seen: set[str] = set()

        for hit in global_results:
            if hit.level == 2:
                initial_candidates.append(hit)
                continue
            if hit.uri in seen:
                continue
            seen.add(hit.uri)
            points.append(hit)

        for uri in root_uris:
            if uri in seen:
                continue
            seen.add(uri)
            points.append(SeedHit(uri=uri, score=0.0, level=0))

        return points, initial_candidates