"""Tests for ``KnowledgeBaseRetriever``."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openjiuwen.core.retrieval.common.config import RetrievalConfig
from openjiuwen.core.retrieval.common.retrieval_result import RetrievalResult
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.base_retriever import (
MilvusBaseRetriever,
RetrieveConfig,
)
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.retriever import (
KnowledgeBaseRetriever,
)
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.embedder import (
AbstractEmbedder,
)
@pytest.fixture(autouse=True)
def _allow_localhost_embedding_url(monkeypatch):
    """Replace the embedding-service URL (SSRF) validator with a no-op.

    Being ``autouse``, this applies to every test in the module. The stub
    embedder is configured with ``http://localhost`` (presumably rejected by
    the validator), and its ``encode`` returns constant vectors without
    calling any service.
    """

    def _skip_validation(url):
        # Accept every URL unconditionally.
        return None

    monkeypatch.setattr(
        "openjiuwen_deepsearch.algorithm.search_tools.retrieval.embedder.validate_embedding_service_url",
        _skip_validation,
    )
class _StubEmbedder(AbstractEmbedder):
    """Offline embedder used only to construct ``KnowledgeBaseRetriever`` in tests.

    Its own methods make no network calls: queries get no instruction prefix
    and every input text encodes to an all-zero vector. Milvus is never
    connected in these tests.
    """

    def __init__(self):
        super().__init__(
            pretrained_model="stub",
            model_dim=4,
            api_url="http://localhost",
            api_token=bytearray(b"x"),
        )

    def get_query_instruction(self, query: str) -> str:
        """Return an empty instruction; the stub never prefixes queries."""
        return ""

    def encode(self, input_texts: list[str], is_query: bool = False):
        """Return one zero vector of length ``self.embed_dim`` per input text.

        ``is_query`` is accepted for interface compatibility and ignored.
        """
        # NOTE(review): assumes AbstractEmbedder exposes ``model_dim`` as
        # ``embed_dim`` — confirm against the base class.
        return [[0.0] * self.embed_dim for _text in input_texts]
def _kb_retriever(kb):
    """Construct a ``KnowledgeBaseRetriever`` backed by the fake knowledge base *kb*.

    ``KnowledgeBaseRetriever._create_knowledge_base`` is patched to return *kb*
    only while the constructor runs, so the retriever wraps the fake instead
    of a real knowledge base.
    """
    kb_factory_patch = patch.object(
        KnowledgeBaseRetriever,
        "_create_knowledge_base",
        return_value=kb,
    )
    with kb_factory_patch:
        retriever = KnowledgeBaseRetriever(
            "localhost", "19530", "default", "", _StubEmbedder()
        )
    return retriever
def _make_kb(results_per_query):
    """Create a fake knowledge base whose ``retrieve`` is an awaitable stub.

    ``results_per_query`` is a list of lists of ``RetrievalResult``. The n-th
    await of ``retrieve`` returns the n-th inner list.
    """
    return MagicMock(retrieve=AsyncMock(side_effect=results_per_query))
def test_knowledge_base_retriever_is_milvus_base_and_skips_client():
    """The retriever subclasses ``MilvusBaseRetriever`` but holds no client of its own."""
    fake_kb = _make_kb([[RetrievalResult(text="x", score=1.0, chunk_id="id0")]])
    retriever = _kb_retriever(fake_kb)

    assert isinstance(retriever, MilvusBaseRetriever)
    assert retriever.client is None
    assert retriever.embedder is not None
    assert retriever.collection_name == ""
def test_retrieve_single_query_returns_json_and_ids():
    """A single query yields one JSON payload of texts plus one ID per result."""
    question = "what is X?"
    hits = [
        RetrievalResult(text="chunk a", score=0.9, chunk_id="c-1", metadata={}),
        # No chunk_id on this hit, so its reported ID is expected to be the doc_id.
        RetrievalResult(text="chunk b", score=0.8, doc_id="d-2", chunk_id=None, metadata={}),
    ]
    fake_kb = _make_kb([hits])
    retriever = _kb_retriever(fake_kb)

    text, id_list = retriever.retrieve(RetrieveConfig(query=[question], top_k=5))

    payload = json.loads(text)
    assert payload["query"] == question
    assert payload["results"] == ["chunk a", "chunk b"]
    assert id_list == ["c-1", "d-2"]

    fake_kb.retrieve.assert_awaited_once()
    awaited = fake_kb.retrieve.await_args
    assert awaited.args[0] == question
    assert awaited.kwargs["config"] == RetrievalConfig(top_k=5)
def test_retrieve_multiple_queries_joined_with_blank_line():
    """Each query gets its own JSON payload, and payloads are separated by a blank line.

    Also checks that result IDs are concatenated across queries and that the
    knowledge base is queried once per input query, with that query's text.
    """
    kb = _make_kb(
        [
            [RetrievalResult(text="one", score=1.0, chunk_id="1")],
            [RetrievalResult(text="two", score=1.0, chunk_id="2")],
        ]
    )
    retriever = _kb_retriever(kb)
    text, id_list = retriever.retrieve(
        RetrieveConfig(query=["first", "second"], top_k=3)
    )
    parts = text.split("\n\n")
    assert len(parts) == 2
    assert json.loads(parts[0]) == {"query": "first", "results": ["one"]}
    assert json.loads(parts[1]) == {"query": "second", "results": ["two"]}
    assert id_list == ["1", "2"]
    assert kb.retrieve.await_count == 2
    # The fake hands out canned results in call order regardless of its
    # arguments, so explicitly verify that each query string was forwarded.
    sent_queries = [awaited.args[0] for awaited in kb.retrieve.await_args_list]
    assert sent_queries == ["first", "second"]
def test_top_k_multiplied_by_factor():
    """The ``top_k`` sent to the knowledge base is ``top_k * top_k_multiply_factor``."""
    kb = _make_kb([[RetrievalResult(text="t", score=1.0, chunk_id="x")]])
    retriever = _kb_retriever(kb)
    retriever.retrieve(
        RetrieveConfig(query=["q"], top_k=4, top_k_multiply_factor=3)
    )
    # If the knowledge base was never queried, fail with a clear mock
    # assertion instead of an AttributeError on a ``None`` ``await_args``.
    kb.retrieve.assert_awaited_once()
    assert kb.retrieve.await_args.kwargs["config"].top_k == 12
@pytest.mark.parametrize(
    "result,expected_id",
    [
        pytest.param(
            RetrievalResult(text="a", score=0.0, chunk_id="cid"),
            "cid",
            id="chunk_id",
        ),
        pytest.param(
            RetrievalResult(text="a", score=0.0, doc_id="did", chunk_id=None),
            "did",
            id="doc_id-when-no-chunk_id",
        ),
        pytest.param(
            RetrievalResult(
                text="a",
                score=0.0,
                metadata={"id": "mid"},
            ),
            "mid",
            id="metadata-id",
        ),
        pytest.param(
            RetrievalResult(
                text="a",
                score=0.0,
                metadata={"chunk_id": "mc"},
            ),
            "mc",
            id="metadata-chunk_id",
        ),
        pytest.param(
            RetrievalResult(text="a", score=0.0, metadata={}),
            "",
            id="no-id-empty-string",
        ),
    ],
)
def test_result_id_resolution(result, expected_id):
    """``_result_id`` resolves an ID from chunk_id, doc_id, or metadata, else ``""``.

    Each case carries exactly one identifier source. Its test ID names that
    source, so a failure report shows which fallback broke.
    """
    assert KnowledgeBaseRetriever._result_id(result) == expected_id
def test_save_as_writes_json_with_raw_results(tmp_path):
    """With ``save_as`` set, retrieval results are also written to a JSON file.

    Each saved record keeps the query, the result texts, and the raw results.
    """
    target = tmp_path / "out.json"
    hit = RetrievalResult(
        text="body",
        score=0.5,
        chunk_id="ch-0",
        metadata={"k": "v"},
    )
    fake_kb = _make_kb([[hit]])
    retriever = _kb_retriever(fake_kb)

    retriever.retrieve(RetrieveConfig(query=["q"], top_k=2, save_as=str(target)))

    assert target.is_file()
    records = json.loads(target.read_text(encoding="utf-8"))
    assert len(records) == 1
    entry = records[0]
    assert entry["query"] == "q"
    assert entry["results"] == ["body"]
    assert "raw_results" in entry
    first_raw = entry["raw_results"][0]
    assert first_raw["text"] == "body"
    assert first_raw["chunk_id"] == "ch-0"