import logging
import math
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from psycopg2.retrievers import RetrievalResult
if TYPE_CHECKING:
from psycopg2.models import BaseModel
# Module-level logger named after this module; used for retrieval/rerank warnings.
logger = logging.getLogger(__name__)
@dataclass
class RetrievalPath:
    """Output of one retrieval path, tagged with the path's name."""

    # Name identifying the retrieval path that produced ``results``.
    name: str
    # Results of that path, kept in the order they were supplied.
    results: List[RetrievalResult]
class FusionStrategy(ABC):
    """Base class for fusion strategies.

    Subclasses can choose whether to use weights:
    - RRFFusion / WeightedFusion: apply per-path weights (split evenly when
      none are given)
    - ModelRerankFusion: does not use weights (the model determines ranking)

    ``fuse()`` is a template method: it resolves one weight per path, lets
    ``_score_path()`` add each path's contribution to shared score/data maps,
    then returns the ``top_k`` highest-scoring documents.
    """

    # Label written to the ``source`` field of fused results; subclasses override it.
    _fuse_source: str = "fused"

    def __init__(self, weights: Optional[List[float]] = None):
        """Validate and store the optional per-path weights.

        Args:
            weights: Optional weight list. Auto-distributed evenly if not provided.
                Each weight must be in [0, 1] and sum to 1.

        Raises:
            ValueError: If a weight is outside [0, 1] or the weights do not
                sum to 1 (within 1e-6).
        """
        if weights is not None:
            # Snapshot the input: later mutation of the caller's list must not
            # bypass the validation below, and a one-shot iterable would
            # otherwise be exhausted by the range check before ``sum()``.
            weights = list(weights)
            for w in weights:
                if not 0 <= w <= 1:
                    raise ValueError(
                        f"Each weight must be in [0, 1], got {w}"
                    )
            total = sum(weights)
            if abs(total - 1.0) > 1e-6:
                raise ValueError(
                    f"Weights must sum to 1, got {weights} (sum={total:.4f})"
                )
        self.weights = weights

    def get_weights(self, num_paths: int) -> List[float]:
        """Return one weight per retrieval path.

        Args:
            num_paths: Number of retrieval paths being fused.

        Returns:
            A copy of the configured weights, or an even split
            (``1 / num_paths`` each) when none were configured. Empty when
            no weights are configured and ``num_paths`` is not positive.

        Raises:
            ValueError: If configured weights do not match ``num_paths`` in length.
        """
        if self.weights is not None:
            if len(self.weights) != num_paths:
                raise ValueError(
                    f"weights length ({len(self.weights)}) != "
                    f"number of retrievers ({num_paths})"
                )
            # Hand out a copy so callers cannot alter the validated weights.
            return list(self.weights)
        if num_paths <= 0:
            # Nothing to distribute; also avoids dividing by zero below.
            return []
        return [1.0 / num_paths] * num_paths

    def _accumulate_score(
        self,
        score_map: Dict[Any, float],
        data_map: Dict[Any, Dict],
        doc_id: Any,
        score: float,
        data: Dict
    ):
        """Accumulate score and merge data for a document.

        If the document already exists, adds score and merges new data fields
        without overwriting existing ones. Otherwise the document is registered
        with ``score`` and a shallow copy of ``data``.
        """
        if doc_id in score_map:
            score_map[doc_id] += score
            existing = data_map[doc_id]
            for key, value in data.items():
                # First writer wins: only fill in fields not seen yet.
                existing.setdefault(key, value)
        else:
            score_map[doc_id] = score
            data_map[doc_id] = dict(data)

    def _build_sorted_results(
        self,
        score_map: Dict[Any, float],
        data_map: Dict[Any, Dict],
        top_k: int,
        source: str
    ) -> List[RetrievalResult]:
        """Sort by score descending and build RetrievalResult list.

        Args:
            score_map: Fused score per document id.
            data_map: Merged data per document id.
            top_k: Maximum number of results; non-positive values yield ``[]``.
            source: Value for the ``source`` field of every result.
        """
        # Clamp so a negative top_k cannot slice from the tail of the ranking.
        limit = max(top_k, 0)
        sorted_ids = sorted(score_map, key=score_map.get, reverse=True)
        return [
            RetrievalResult(
                id=doc_id,
                score=score_map[doc_id],
                data=data_map[doc_id],
                source=source
            )
            for doc_id in sorted_ids[:limit]
        ]

    @abstractmethod
    def _score_path(self, path: RetrievalPath, weight: float,
                    score_map: Dict, data_map: Dict):
        """Score a single retrieval path. Subclasses must implement.

        Implementations add the path's contribution to ``score_map`` and
        ``data_map`` in place (typically via ``_accumulate_score``).
        """

    def fuse(
        self,
        paths: List[RetrievalPath],
        top_k: int = 10
    ) -> List[RetrievalResult]:
        """Fuse multi-path retrieval results (template method).

        Iterates over paths with weights, delegates per-path scoring
        to ``_score_path()``, then sorts and returns top-k results.

        Args:
            paths: Retrieval paths, positionally aligned with the weights.
            top_k: Maximum number of fused results to return.

        Returns:
            Fused results sorted by score descending; ``[]`` when ``paths``
            is empty.

        Raises:
            ValueError: If configured weights do not match ``len(paths)``.
        """
        if not paths:
            return []
        weights = self.get_weights(len(paths))
        score_map: Dict[Any, float] = {}
        data_map: Dict[Any, Dict] = {}
        for path, weight in zip(paths, weights):
            self._score_path(path, weight, score_map, data_map)
        return self._build_sorted_results(score_map, data_map, top_k, self._fuse_source)
class RRFFusion(FusionStrategy):
    """Reciprocal Rank Fusion.

    RRF formula: score = Σ (weight / (k + rank)), with ranks starting at 1
    within each retrieval path.

    Examples:
        # Default even weights
        fusion = RRFFusion(k=60)
        # Custom weights
        fusion = RRFFusion(k=60, weights=[0.6, 0.4])
    """

    _fuse_source = "rrf_fused"

    def __init__(self, k: int = 60, weights: Optional[List[float]] = None):
        """
        Args:
            k: Rank-smoothing constant added to every rank; must be >= 0.
            weights: Optional weight list. Auto-distributed evenly if not provided.

        Raises:
            ValueError: If ``k`` is negative, or the weights are invalid.
        """
        # Ranks start at 1, so a negative k can make ``k + rank`` zero
        # (division by zero) or negative (inverted ranking).
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        super().__init__(weights)
        self.k = k

    def _score_path(self, path: RetrievalPath, weight: float,
                    score_map: Dict, data_map: Dict):
        """Compute RRF scores for a single retrieval path.

        Each result contributes ``weight / (k + rank)`` where ``rank`` is its
        1-based position in ``path.results``.
        """
        for rank, result in enumerate(path.results, 1):
            rrf_score = weight / (self.k + rank)
            self._accumulate_score(score_map, data_map, result.id, rrf_score, result.data)
class NormMethod(Enum):
    """Available ways to normalize per-path scores before weighting."""

    # Squash scores with the arctan function.
    ARCTAN = "arctan"
    # Rescale scores using the path's own minimum and range.
    MIN_MAX = "min_max"
    # Leave scores untouched.
    NONE = "none"
class WeightedFusion(FusionStrategy):
    """Weighted score fusion.

    Normalizes scores from each path then computes weighted sum.

    Normalization methods:
    - arctan: normalize to [0, 1] using arctan function (default)
    - min_max: normalize to [0, 1] using min-max scaling
    - none: use raw scores directly (when scores are already on the same scale)

    Examples:
        # Default arctan normalization
        fusion = WeightedFusion(weights=[0.7, 0.3])
        # Min-max normalization
        fusion = WeightedFusion(weights=[0.7, 0.3], norm_method=NormMethod.MIN_MAX)
        # No normalization (scores already on the same scale)
        fusion = WeightedFusion(weights=[0.7, 0.3], norm_method=NormMethod.NONE)
    """

    _fuse_source = "weighted_fused"

    def __init__(
        self,
        weights: Optional[List[float]] = None,
        norm_method: NormMethod = NormMethod.ARCTAN
    ):
        """
        Args:
            weights: Optional weight list. Auto-distributed evenly if not provided.
            norm_method: Normalization method, defaults to arctan. A raw enum
                value such as ``"min_max"`` is accepted and converted.

        Raises:
            ValueError: If the weights are invalid or ``norm_method`` is not a
                ``NormMethod`` member or value.
        """
        super().__init__(weights)
        # Coerce through the enum: a plain string like "min_max" would
        # otherwise match no branch in ``_normalize_score`` and silently
        # behave as "no normalization".
        self.norm_method = NormMethod(norm_method)

    def _normalize_score(
        self,
        score: float,
        min_score: Optional[float] = None,
        score_range: Optional[float] = None
    ) -> float:
        """Normalize score to [0, 1].

        Args:
            score: Raw score.
            min_score: Minimum score (required for min_max method).
            score_range: Score range (required for min_max method).

        Returns:
            Normalized score. For min_max, 0.5 is returned when the range is
            missing or not positive (e.g. all scores in the path are equal).
            For the "none" method the raw score is returned unchanged.
        """
        if self.norm_method == NormMethod.ARCTAN:
            # atan maps to (-pi/2, pi/2); divide by pi and shift by 0.5.
            return (math.atan(score) / math.pi) + 0.5
        if self.norm_method == NormMethod.MIN_MAX:
            if min_score is not None and score_range is not None and score_range > 0:
                return (score - min_score) / score_range
            return 0.5
        return score

    def _compute_range(self, results: List[RetrievalResult]):
        """Compute min_score and score_range for min_max normalization.

        Returns:
            ``(min_score, max_score - min_score)`` for the min_max method,
            or ``(None, None)`` for other methods or an empty ``results``.
        """
        if self.norm_method != NormMethod.MIN_MAX or not results:
            return None, None
        scores = [r.score for r in results]
        min_score = min(scores)
        return min_score, max(scores) - min_score

    def _score_path(self, path: RetrievalPath, weight: float,
                    score_map: Dict, data_map: Dict):
        """Normalize and weight scores for a single retrieval path.

        Each result contributes ``normalized_score * weight``; an empty path
        contributes nothing.
        """
        if not path.results:
            return
        # The range is a property of the whole path, so compute it once.
        min_score, score_range = self._compute_range(path.results)
        for result in path.results:
            normalized = self._normalize_score(result.score, min_score, score_range)
            self._accumulate_score(
                score_map, data_map,
                result.id, normalized * weight, result.data
            )
class ModelRerankFusion(FusionStrategy):
    """Model-based reranking fusion strategy.

    Uses a model's rerank capability to semantically reorder multi-path recall results.

    Unlike RRF/Weighted fusion which score per-path independently, this strategy:
    1. Merges multi-path recall results (deduplicate)
    2. Extracts document text
    3. Calls model rerank API in a single batch
    4. Returns results sorted by semantic relevance

    Example:
        >>> from psycopg2.models import create_model
        >>>
        >>> model = create_model("cohere", api_key="xxx")
        >>> fusion = ModelRerankFusion(model=model, query="machine learning")
        >>>
        >>> results = client.hybrid_search(
        ...     "documents",
        ...     retrievers=[vec_ret, ft_ret],
        ...     fusion_strategy=fusion
        ... )
    """

    _fuse_source = "model_rerank"

    def __init__(
        self,
        model: "BaseModel",
        query: str,
        text_field: str = "content",
        fallback_to_rrf: bool = True
    ):
        """
        Args:
            model: Pre-constructed model instance (must support rerank).
            query: Query text (used for rerank).
            text_field: Text field name used for rerank input.
            fallback_to_rrf: Whether to fall back to RRF if rerank fails.
        """
        super().__init__(weights=None)
        self.model = model
        self.query = query
        self.text_field = text_field
        self.fallback_to_rrf = fallback_to_rrf

    def _extract_text(self, data: Dict[str, Any]) -> str:
        """Extract text from document data.

        Returns the value of ``text_field`` as a string. When the field is
        missing or ``None``, falls back to ``str(data)`` so the reranker still
        receives a string for the document.
        """
        value = data.get(self.text_field)
        if value is None:
            return str(data)
        # Non-string field values are stringified to honour the ``str`` return type.
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def _update_merged_entry(entry: Dict, result: 'RetrievalResult') -> None:
        """Update an existing merged entry with a new result.

        Keeps the highest score and merges data fields without
        overwriting existing ones (consistent with base class
        ``_accumulate_score``).
        """
        if result.score > entry["score"]:
            entry["score"] = result.score
        merged_data = entry["data"]
        for key, value in result.data.items():
            merged_data.setdefault(key, value)

    def _merge_results(self, paths: List[RetrievalPath]) -> Dict[Any, Dict]:
        """Merge multi-path recall results (deduplicate, keep highest score).

        Data merge follows the same convention as the base class
        ``_accumulate_score``: new fields are added but existing fields
        are never overwritten.

        Returns:
            Mapping of document id to ``{"data": ..., "score": ...}``, in
            first-seen order.
        """
        merged: Dict[Any, Dict] = {}
        for path in paths:
            for result in path.results:
                entry = merged.get(result.id)
                if entry is None:
                    merged[result.id] = {
                        "data": dict(result.data),
                        "score": result.score,
                    }
                else:
                    self._update_merged_entry(entry, result)
        return merged

    def _score_path(self, path: RetrievalPath, weight: float,
                    score_map: Dict, data_map: Dict):
        """Not used — ModelRerankFusion overrides fuse() directly."""
        raise NotImplementedError(
            "ModelRerankFusion uses model rerank instead of per-path scoring"
        )

    def _build_reranked_results(
        self,
        rerank_results,
        doc_ids: List[Any],
        merged: Dict[Any, Dict]
    ) -> List[RetrievalResult]:
        """Convert the model's rerank output into RetrievalResult objects.

        Args:
            rerank_results: Iterable of items read as ``item["index"]``
                (position in ``doc_ids``) and ``item["score"]``.
            doc_ids: Document ids in the order their texts were sent to the model.
            merged: Merged entries keyed by document id.

        Raises:
            IndexError: If an index falls outside ``doc_ids``. Negative
                indices are rejected too, since Python would otherwise wrap
                them around to an unrelated document.
            KeyError: If an item lacks ``"index"`` or ``"score"``.
        """
        results = []
        for item in rerank_results:
            idx = item["index"]
            if not 0 <= idx < len(doc_ids):
                raise IndexError(
                    f"rerank returned index {idx} for {len(doc_ids)} documents"
                )
            doc_id = doc_ids[idx]
            results.append(RetrievalResult(
                id=doc_id,
                score=item["score"],
                data=merged[doc_id]["data"],
                source=self._fuse_source
            ))
        return results

    def fuse(
        self,
        paths: List[RetrievalPath],
        top_k: int = 10
    ) -> List[RetrievalResult]:
        """Fuse results using model rerank.

        Overrides the base template method because reranking requires
        a single batch call across all merged documents.

        Args:
            paths: Multi-path retrieval results.
            top_k: Maximum number of results to return.

        Returns:
            Fused retrieval result list, in the order returned by the model;
            ``[]`` when there is nothing to rerank or ``top_k`` is not positive.

        Raises:
            Exception: Whatever the rerank call or response parsing raised,
                when ``fallback_to_rrf`` is False.
        """
        # Nothing can be returned for a non-positive top_k, so skip the model call.
        if not paths or top_k <= 0:
            return []
        merged = self._merge_results(paths)
        if not merged:
            return []
        doc_ids = list(merged)
        documents = [self._extract_text(merged[did]["data"]) for did in doc_ids]
        try:
            rerank_results = self.model.rerank(
                query=self.query,
                documents=documents,
                top_n=top_k
            )
            # Parsed inside the try so a malformed response is treated like
            # any other rerank failure and can use the RRF fallback.
            results = self._build_reranked_results(rerank_results, doc_ids, merged)
        except Exception as e:
            logger.warning(f"Model rerank failed: {e}")
            if self.fallback_to_rrf:
                logger.info("Falling back to RRF fusion")
                return RRFFusion(k=60).fuse(paths, top_k)
            raise
        return results[:top_k]
class MultiRetrievalEngine:
    """Multi-path retrieval engine.

    Runs every configured retriever against a table and merges their results
    with a fusion strategy.

    Examples:
        # 1. Create retrievers
        vector_ret = VectorRetriever(query_vector, metric="cosine")
        text_ret = FullTextRetriever(query_text)
        # 2. Create engine (weights configured in fusion strategy)
        engine = MultiRetrievalEngine(
            retrievers=[vector_ret, text_ret],
            fusion_strategy=RRFFusion(k=60, weights=[0.6, 0.4])
        )
        # 3. Execute search
        results = engine.search(client, "documents", top_k=10)
        # Example 2: default even weights
        engine = MultiRetrievalEngine(
            retrievers=[vector_ret, text_ret],
            fusion_strategy=RRFFusion()  # auto even [0.5, 0.5]
        )
    """

    # Timeout in seconds used by ``search`` when none is given.
    DEFAULT_TIMEOUT = 30

    def __init__(
        self,
        retrievers: List,
        fusion_strategy: Optional[FusionStrategy] = None
    ):
        """
        Args:
            retrievers: Non-empty list of retrievers; each must provide
                ``retrieve(**kwargs)`` and ``get_name()``.
            fusion_strategy: Strategy used to merge paths; defaults to
                ``RRFFusion()``.

        Raises:
            ValueError: If ``retrievers`` is empty.
        """
        if not retrievers:
            raise ValueError("retrievers list must not be empty")
        self.retrievers = retrievers
        # Compare against None rather than truthiness so a supplied strategy
        # is never discarded just because it evaluates as falsy.
        if fusion_strategy is None:
            fusion_strategy = RRFFusion()
        self.fusion_strategy = fusion_strategy

    def search(
        self,
        client,
        table_name: str,
        top_k: int = 10,
        filter_condition: Optional[str] = None,
        filter_params: Optional[Dict] = None,
        output_columns: Optional[List[str]] = None,
        per_path_top_k: Optional[int] = None,
        parallel: bool = True,
        max_workers: Optional[int] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> List[Dict]:
        """Execute multi-path retrieval with optional parallelism.

        Args:
            client: Database client, passed to every retriever.
            table_name: Table name.
            top_k: Number of final results to return.
            filter_condition: SQL WHERE clause for filtering.
            filter_params: Parameters for filter condition.
            output_columns: Columns to include in output.
            per_path_top_k: Results per retrieval path (defaults to top_k * 2).
            parallel: Whether to run retrievers in a thread pool (default
                True); the same ``client`` is handed to every worker.
            max_workers: Max thread count (defaults to number of retrievers).
            timeout: Seconds to wait for each retriever's result (default
                30s). Only applied when ``parallel`` is True.
            **kwargs: Extra arguments passed to each retriever.

        Returns:
            Fused results as list of dicts: each result's data fields plus
            ``id``, ``score`` and ``source`` (which take precedence over
            same-named data fields).
        """
        if per_path_top_k is None:
            # Over-fetch per path so fusion has more candidates than top_k.
            per_path_top_k = top_k * 2
        if timeout is None:
            timeout = self.DEFAULT_TIMEOUT
        retrieve_kwargs = dict(
            client=client,
            table_name=table_name,
            top_k=per_path_top_k,
            filter_condition=filter_condition,
            filter_params=filter_params,
            output_columns=output_columns,
            **kwargs
        )
        if parallel:
            paths = self._search_parallel(retrieve_kwargs, max_workers, timeout)
        else:
            paths = self._search_sequential(retrieve_kwargs)
        if not paths:
            return []
        fused_results = self.fusion_strategy.fuse(paths, top_k)
        return [
            {
                **r.data,
                'id': r.id,
                'score': r.score,
                'source': r.source,
            }
            for r in fused_results
        ]

    @staticmethod
    def _safe_retrieve(retriever, retrieve_kwargs: Dict) -> RetrievalPath:
        """Execute a single retriever, returning empty path on failure.

        Any exception raised by ``retriever.retrieve`` is logged as a warning
        and replaced by a path with no results.
        """
        try:
            results = retriever.retrieve(**retrieve_kwargs)
            return RetrievalPath(name=retriever.get_name(), results=results)
        except Exception as e:
            logger.warning(f"{retriever.get_name()} retrieval failed: {e}")
            return RetrievalPath(name=retriever.get_name(), results=[])

    def _search_sequential(self, retrieve_kwargs: Dict) -> List[RetrievalPath]:
        """Execute each retriever sequentially.

        Note: Failed retrievers produce an empty-result path (not skipped)
        to preserve positional alignment with weights.
        """
        return [
            self._safe_retrieve(retriever, retrieve_kwargs)
            for retriever in self.retrievers
        ]

    def _search_parallel(
        self,
        retrieve_kwargs: Dict,
        max_workers: Optional[int] = None,
        timeout: Optional[float] = None
    ) -> List[RetrievalPath]:
        """Execute each retriever in parallel using a thread pool.

        Note: Failed/timed-out retrievers produce an empty-result path (not skipped)
        to preserve positional alignment with weights.

        Args:
            retrieve_kwargs: Keyword arguments forwarded to every retriever.
            max_workers: Thread count; defaults to the number of retrievers.
            timeout: Seconds to wait on each retriever's future (``None``
                waits indefinitely).

        Returns:
            One path per retriever, in the same order as ``self.retrievers``.
        """
        if max_workers is None:
            max_workers = len(self.retrievers)
        paths: List[RetrievalPath] = []
        # Deliberately not a ``with`` block: leaving it calls
        # ``shutdown(wait=True)``, which would block until every worker
        # finished and make the timeout below ineffective.
        executor = ThreadPoolExecutor(max_workers=max_workers)
        try:
            futures = [
                executor.submit(self._safe_retrieve, retriever, retrieve_kwargs)
                for retriever in self.retrievers
            ]
            for retriever, future in zip(self.retrievers, futures):
                try:
                    paths.append(future.result(timeout=timeout))
                except FuturesTimeoutError:
                    # Drops the task if it is still queued; a call that is
                    # already running cannot be interrupted and finishes in
                    # the background.
                    future.cancel()
                    name = retriever.get_name()
                    logger.warning(f"{name} retrieval timed out after {timeout}s")
                    paths.append(RetrievalPath(name=name, results=[]))
                except Exception as e:
                    name = retriever.get_name()
                    logger.warning(f"{name} retrieval task failed: {e}")
                    paths.append(RetrievalPath(name=name, results=[]))
        finally:
            # Release the pool without waiting for stragglers.
            executor.shutdown(wait=False)
        return paths