"""Candidate extraction pipeline.
Runs all CandidateExtractor implementations in parallel and aggregates results.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable
from core.models import RequestContext, CandidateMemory
from core.interfaces import CandidateExtractor
from core.logging_config import get_logger
from core.validation import validate_candidate, validate_attribution, ValidationError
logger = get_logger(__name__)
class CandidatePipeline:
"""Pipeline for parallel candidate extraction.
Runs all extractors concurrently and aggregates results.
"""
def __init__(self, extractors: list[CandidateExtractor] | None = None):
"""Initialize CandidatePipeline.
Args:
extractors: List of CandidateExtractor instances.
If None, must be set before calling extract().
"""
self._extractors = extractors or []
def set_extractors(self, extractors: list[CandidateExtractor]) -> None:
"""Set the extractors to use.
Args:
extractors: List of CandidateExtractor instances
"""
self._extractors = extractors
def extract(
self,
messages: list[dict],
ctx: RequestContext,
parallel: bool = True,
session_time=None,
session_summary: str = "",
tool_stats_text: str = "",
archive_id: str | None = None,
) -> list[CandidateMemory]:
"""Extract candidates from messages using all extractors.
Args:
messages: List of message dicts with "role" and "content"
ctx: RequestContext for this extraction
parallel: If True, run extractors in parallel (default: True)
session_time: Optional datetime for temporal resolution
session_summary: Optional summary of previously extracted content
tool_stats_text: Optional tool usage statistics text
archive_id: Optional archive_id for provenance tracking
Returns:
List of CandidateMemory from all extractors
"""
if parallel:
return self._extract_parallel(messages, ctx, session_time, session_summary, tool_stats_text, archive_id)
else:
return self._extract_serial(messages, ctx, session_time, session_summary, tool_stats_text, archive_id)
def _extract_parallel(
self,
messages: list[dict],
ctx: RequestContext,
session_time=None,
session_summary: str = "",
tool_stats_text: str = "",
archive_id: str | None = None,
) -> list[CandidateMemory]:
"""Run extractors in parallel using ThreadPoolExecutor.
Args:
messages: List of message dicts
ctx: RequestContext
session_time: Optional datetime for temporal resolution
session_summary: Optional summary of previously extracted content
tool_stats_text: Optional tool usage statistics text
archive_id: Optional archive_id for provenance tracking
Returns:
List of CandidateMemory from all extractors
"""
all_candidates = []
with ThreadPoolExecutor(max_workers=len(self._extractors)) as executor:
futures = {
executor.submit(
extractor.extract, messages, ctx, session_time,
session_summary, tool_stats_text, archive_id,
): extractor
for extractor in self._extractors
}
for future in as_completed(futures):
extractor = futures[future]
try:
candidates = future.result()
all_candidates.extend(candidates)
except Exception as e:
logger.error(
"Extractor %s failed: %s",
extractor.__class__.__name__, e,
exc_info=True
)
validated_candidates = []
for candidate in all_candidates:
try:
candidate = validate_attribution(candidate, user_id=ctx.user_id if ctx else None)
validated_candidates.append(validate_candidate(candidate))
except ValidationError as e:
logger.warning("Rejected invalid candidate: %s", e)
return validated_candidates
def _extract_serial(
self,
messages: list[dict],
ctx: RequestContext,
session_time=None,
session_summary: str = "",
tool_stats_text: str = "",
archive_id: str | None = None,
) -> list[CandidateMemory]:
"""Run extractors serially.
Args:
messages: List of message dicts
ctx: RequestContext
session_time: Optional datetime for temporal resolution
session_summary: Optional summary of previously extracted content
tool_stats_text: Optional tool usage statistics text
archive_id: Optional archive_id for provenance tracking
Returns:
List of CandidateMemory from all extractors
"""
all_candidates = []
for extractor in self._extractors:
try:
candidates = extractor.extract(
messages, ctx, session_time,
session_summary, tool_stats_text, archive_id,
)
all_candidates.extend(candidates)
except Exception as e:
logger.error(
"Extractor %s failed: %s",
extractor.__class__.__name__, e,
exc_info=True
)
validated_candidates = []
for candidate in all_candidates:
try:
candidate = validate_attribution(candidate, user_id=ctx.user_id if ctx else None)
validated_candidates.append(validate_candidate(candidate))
except ValidationError as e:
logger.warning("Rejected invalid candidate: %s", e)
return validated_candidates
def filter_by_confidence(
self,
candidates: list[CandidateMemory],
threshold: float = 0.5
) -> list[CandidateMemory]:
"""Filter candidates by confidence threshold.
Args:
candidates: List of CandidateMemory
threshold: Minimum confidence score (default: 0.5)
Returns:
Filtered list of CandidateMemory
"""
return [c for c in candidates if c.confidence >= threshold]
def deduplicate(
self,
candidates: list[CandidateMemory]
) -> list[CandidateMemory]:
"""Deduplicate candidates by (category, routing_key, abstract prefix).
Keeps the candidate with the highest confidence per group.
This catches duplicates that slip through dual-run merge,
e.g. same fact extracted from different spans.
"""
seen: dict[tuple[str, str, str], CandidateMemory] = {}
for c in candidates:
key = (c.category, c.routing_key, c.abstract[:80])
if key not in seen or c.confidence > seen[key].confidence:
seen[key] = c
return list(seen.values())