"""Stage 1 — Rule-based query planner.
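Classifies a natural-language query into one or more TypedQuery objects
using keyword matching, with a deterministic fallback to the MEMORY context
type when zero or several types match. If an LLM client is supplied, longer
queries may first be decomposed into sub-queries.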
"""
from __future__ import annotations
import logging
import re
from typing import Sequence
from core.enums import ContextType
from core.errors import ValidationError
from core.models import RetrievalConfig, RequestContext, TypedQuery
from retrieval.intent_classifier import RetrievalIntentClassifier
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Case-insensitive keywords that tag a query as ContextType.MEMORY in
# QueryPlanner._classify. Chinese terms: 偏好 = preference, 背景 = background,
# 身份 = identity, 记忆 = memory, 回顾 = look back / recall.
_MEMORY_PATTERNS: list[re.Pattern[str]] = [
    re.compile(
        r"偏好|背景|身份|记忆|回顾|profile|preference|background|remember",
        re.IGNORECASE,
    ),
]
# Case-insensitive keywords that tag a query as ContextType.SKILL in
# QueryPlanner._classify. Chinese terms: 怎么做 = how to do it, 步骤 = steps,
# 命令 = command, 流程 = process / workflow.
_SKILL_PATTERNS: list[re.Pattern[str]] = [
    re.compile(
        r"怎么做|步骤|命令|流程|how\s*to|workflow|procedure|instruction|tutorial",
        re.IGNORECASE,
    ),
]
# Case-insensitive keywords that tag a query as ContextType.RESOURCE in
# QueryPlanner._classify. Chinese terms: 文档 = document, 资料 = materials,
# 规范 = specification, 链接 = link.
_RESOURCE_PATTERNS: list[re.Pattern[str]] = [
    re.compile(
        r"文档|资料|规范|链接|document|spec|reference|link|paper|article",
        re.IGNORECASE,
    ),
]
def _match_any(text: str, patterns: Sequence[re.Pattern[str]]) -> bool:
    """Return True if at least one of *patterns* matches anywhere in *text*.

    Uses ``Pattern.search`` (not ``match``), so a hit at any position counts.
    An empty *patterns* sequence yields False.
    """
    for candidate in patterns:
        if candidate.search(text):
            return True
    return False
# Patterns whose matches sanitize_query() deletes from the raw query, applied
# in this order.
_NOISE_PATTERNS: list[re.Pattern[str]] = [
    re.compile(source)
    for source in (
        # Fenced code blocks; non-greedy and spanning newlines.
        r"```[\s\S]*?```",
        # The literal "Sender (untrusted metadata):" label and trailing spaces.
        r"Sender\s*\(untrusted metadata\)\s*:\s*",
        # A bracketed stamp made of word chars, spaces, ':' and '-' ending in "UTC]".
        r"\[[\w\s:\-]+UTC\]\s*",
    )
]
def sanitize_query(raw: str) -> str:
    """Remove upstream metadata noise from *raw* and trim the result.

    Every pattern in ``_NOISE_PATTERNS`` is applied in order, deleting each
    match; leading and trailing whitespace is then stripped. The result may
    be an empty string if nothing but noise was present.
    """
    cleaned = raw
    for noise in _NOISE_PATTERNS:
        cleaned = noise.sub("", cleaned)
    cleaned = cleaned.strip()
    return cleaned
class QueryPlanner:
    """Turn a raw user query into a list of ``TypedQuery`` objects.

    Planning steps (see :meth:`plan`):

    1. Sanitize the raw query and validate ``top_k``.
    2. If an LLM client was supplied, try to decompose the query into
       sub-queries; every sub-query is planned as a MEMORY query.
    3. Otherwise classify by keyword: exactly one matching type wins, while
       zero or several matches fall back to MEMORY.
    """

    def __init__(self, config: RetrievalConfig | None = None, llm=None) -> None:
        """Create a planner.

        Args:
            config: Retrieval settings; ``RetrievalConfig()`` is used when
                omitted (or falsy).
            llm: Optional client exposing ``complete_json(prompt, schema=...)``.
                When ``None``, query decomposition is skipped entirely.
        """
        self.config = config or RetrievalConfig()
        self.intent_classifier = RetrievalIntentClassifier()
        self._llm = llm

    def plan(
        self,
        query: str,
        ctx: RequestContext,
        *,
        session_archive: dict | None = None,
        hints: dict | None = None,
        categories: list[str] | None = None,
        top_k: int | None = None,
    ) -> list[TypedQuery]:
        """Plan retrieval for *query* and return one or more typed queries.

        Args:
            query: Raw user query; sanitized with ``sanitize_query`` first.
            ctx: Request context supplying account and owner-space data.
            session_archive: Currently unused by this method.
            hints: Currently unused by this method.
            categories: Optional category filter copied onto every query.
            top_k: Result limit; a falsy value (``None`` or ``0``) selects
                ``config.default_top_k``.

        Returns:
            One ``TypedQuery`` per sub-query when decomposition succeeded,
            otherwise a single-element list.

        Raises:
            ValidationError: If the sanitized query is empty, or the
                effective ``top_k`` exceeds ``config.max_top_k``.
        """
        query = sanitize_query(query or "")
        if not query:
            raise ValidationError("query", "query must not be empty")

        # NOTE(review): a negative top_k passes the upper-bound check below
        # unchanged — confirm whether TypedQuery or a caller rejects it.
        effective_top_k = top_k or self.config.default_top_k
        if effective_top_k > self.config.max_top_k:
            raise ValidationError(
                "top_k",
                f"top_k={effective_top_k} exceeds maximum {self.config.max_top_k}",
            )

        sub_queries = self._decompose(query)
        if len(sub_queries) > 1:
            logger.info(
                "[QueryPlanner] decomposed '%s' into %d sub-queries: %s",
                query,
                len(sub_queries),
                sub_queries,
            )
            # Decomposed sub-queries are not keyword-classified; each one is
            # issued against the MEMORY context type.
            return [
                self._make_query(
                    sq, ContextType.MEMORY.value, ctx,
                    categories=categories, top_k=effective_top_k,
                )
                for sq in sub_queries
            ]

        matched_types = self._classify(query)
        if len(matched_types) == 1:
            return [
                self._make_query(
                    query, matched_types[0], ctx,
                    categories=categories, top_k=effective_top_k,
                )
            ]

        # Zero or multiple keyword families matched: fall back to MEMORY.
        logger.info("[QueryPlanner] matched %d types, defaulting to MEMORY", len(matched_types))
        return [
            self._make_query(
                query, ContextType.MEMORY.value, ctx,
                categories=categories, top_k=effective_top_k,
            )
        ]

    @staticmethod
    def _classify(text: str) -> list[str]:
        """Return the context-type values whose keyword patterns match *text*.

        Types are reported in the fixed order MEMORY, SKILL, RESOURCE; the
        list is empty when nothing matches.
        """
        hits: list[str] = []
        if _match_any(text, _MEMORY_PATTERNS):
            hits.append(ContextType.MEMORY.value)
        if _match_any(text, _SKILL_PATTERNS):
            hits.append(ContextType.SKILL.value)
        if _match_any(text, _RESOURCE_PATTERNS):
            hits.append(ContextType.RESOURCE.value)
        return hits

    def _decompose(self, query: str) -> list[str]:
        """Decompose a complex query into 2-4 sub-queries via the LLM.

        Returns ``[query]`` (a single-element list) when decomposition is
        skipped, fails, or yields an unusable answer, so it is safe to call
        unconditionally. Returned sub-queries are whitespace-stripped and
        capped at four.
        """
        # NOTE(review): the length gate counts whitespace-separated tokens,
        # so text written without spaces (e.g. Chinese) is never decomposed
        # — confirm this is intended.
        if self._llm is None or len(query.split()) < 5:
            return [query]

        try:
            result = self._llm.complete_json(
                (
                    "Decompose the following question into 2-4 independent search "
                    "sub-queries, each focusing on one aspect.\n\n"
                    f"Question: {query}"
                ),
                schema={
                    "type": "object",
                    "properties": {
                        "sub_queries": {
                            "type": "array",
                            "items": {"type": "string"},
                            "minItems": 2,
                            "maxItems": 4,
                        }
                    },
                    "required": ["sub_queries"],
                },
            )
            subs = result.get("sub_queries", [])
        except Exception:
            # Best-effort step: any LLM/client failure degrades to the
            # original query. exc_info keeps the cause visible at DEBUG level.
            logger.debug(
                "[QueryPlanner] decomposition failed, using original query",
                exc_info=True,
            )
            return [query]

        # Do not rely on the client enforcing the schema: re-check the shape.
        if (
            isinstance(subs, list)
            and len(subs) >= 2
            and all(isinstance(s, str) and s.strip() for s in subs)
        ):
            # Strip here so the returned text agrees with what was validated.
            return [s.strip() for s in subs[:4]]
        return [query]

    def _make_query(
        self,
        text: str,
        context_type: str,
        ctx: RequestContext,
        *,
        categories: list[str] | None,
        top_k: int,
    ) -> TypedQuery:
        """Build a ``TypedQuery`` for *text*, resolving its owner-space scope.

        ``owner_space`` ends up as one of:

        * SKILL queries: the agent's own space (a string) when an agent id is
          set and that space is either unrestricted or listed as visible;
          else the list of visible spaces starting with ``"agent:"``; else
          the string ``"__inaccessible__"``.
        * Other types: the list of visible owner spaces, or ``None`` when the
          context lists none.

        The retrieval intent is computed from *text* by the intent classifier.
        """
        # Tolerates contexts lacking the attribute or carrying an empty value.
        visible_spaces = list(ctx.visible_owner_spaces) if getattr(ctx, "visible_owner_spaces", ()) else []

        if context_type == ContextType.SKILL.value:
            visible_agent_spaces = [space for space in visible_spaces if str(space).startswith("agent:")]
            agent_space = ctx.agent_space_name() if ctx.agent_id else ""
            if agent_space and (not visible_spaces or agent_space in visible_spaces):
                owner_space = agent_space
            elif visible_agent_spaces:
                owner_space = visible_agent_spaces
            else:
                # Presumably a sentinel that matches no stored owner space —
                # TODO confirm against the retrieval backend.
                owner_space = "__inaccessible__"
        elif visible_spaces:
            owner_space = visible_spaces
        else:
            owner_space = None

        intent = self.intent_classifier.classify(text)
        return TypedQuery(
            text=text,
            context_type=context_type,
            categories=categories or [],
            top_k=top_k,
            account_id=ctx.account_id,
            owner_space=owner_space,
            intent=intent.value,
        )