from unittest.mock import Mock, patch, MagicMock, AsyncMock
import aiohttp
import pytest
import requests
from pydantic import SecretStr
from openjiuwen_deepsearch.framework.openjiuwen.tools.search_api.local_search_api.api_wrapper import \
LocalDatasetAPIWrapper
# Dotted path of the module under test; the tests below prefix it to attribute
# names (e.g. ``requests.post``, ``logger``) to build ``patch(...)`` targets.
MODULE_PATH = "openjiuwen_deepsearch.framework.openjiuwen.tools.search_api.local_search_api.api_wrapper"
class TestLocalDatasetAPIWrapper:
    """Unit tests for ``LocalDatasetAPIWrapper``.

    The HTTP transports (``requests.post`` / ``aiohttp.ClientSession``), the SSL
    configuration lookup and, where a test inspects log output, the logger and
    the sensitivity switch are patched on the module under test, so no network
    traffic is generated.
    """

    @staticmethod
    def _sample_output_list():
        """Return five fake search hits with descending similarity scores."""
        scores = (0.95, 0.85, 0.75, 0.65, 0.55)
        return [
            {"content": f"Result {idx}", "similarity": score, "metadata": {"source": f"doc{idx}"}}
            for idx, score in enumerate(scores, start=1)
        ]

    @staticmethod
    def _fake_headers():
        """Return a fresh headers dict used as the mocked ``build_headers`` result."""
        return {"X-Auth-Token": "key", "Content-Type": "application/json"}

    def setup_method(self):
        """Create the constructor inputs and a fresh wrapper before every test."""
        self.search_api_key = bytearray(b"test_local_api_key_123")
        self.search_url = SecretStr("https://api.localdataset.com/search")
        self.search_datasets = ["dataset_1", "dataset_2", "dataset_3"]
        self.max_local_search_results = 5
        self.recall_threshold = 0.7
        self.wrapper = LocalDatasetAPIWrapper(
            search_api_key=self.search_api_key,
            search_url=self.search_url,
            search_datasets=self.search_datasets,
            max_local_search_results=self.max_local_search_results,
            recall_threshold=self.recall_threshold
        )

    def test_initialization(self):
        """Constructor arguments are stored on the wrapper unchanged."""
        assert self.wrapper.search_api_key == self.search_api_key
        assert self.wrapper.search_url.get_secret_value() == "https://api.localdataset.com/search"
        assert self.wrapper.search_datasets == ["dataset_1", "dataset_2", "dataset_3"]
        assert self.wrapper.max_local_search_results == 5
        assert self.wrapper.recall_threshold == 0.7
        assert self.wrapper.extension is None

    def test_initialization_default_values(self):
        """Omitted optional arguments fall back to their defaults."""
        wrapper = LocalDatasetAPIWrapper(
            search_api_key=bytearray(b"key"),
            search_url=SecretStr("https://example.com")
        )
        assert wrapper.search_datasets == []
        assert wrapper.max_local_search_results == 5
        assert wrapper.recall_threshold == 0.5

    def test_results_method(self):
        """``results`` delegates to ``_search_api_results`` and returns its output."""
        mock_results = [
            {"content": "Result 1", "similarity": 0.9},
            {"content": "Result 2", "similarity": 0.8}
        ]
        with patch.object(self.wrapper, '_search_api_results', return_value=mock_results) as mock_search:
            result = self.wrapper.results("test query")

        mock_search.assert_called_once_with("test query", num=5)
        assert result == mock_results

    @pytest.mark.asyncio
    async def test_aresults_method(self):
        """``aresults`` delegates to ``_async_search_api_results`` and returns its output."""
        mock_results = [
            {"content": "Async Result 1", "similarity": 0.95},
            {"content": "Async Result 2", "similarity": 0.85}
        ]
        with patch.object(self.wrapper, '_async_search_api_results') as mock_async_search:
            mock_async_search.return_value = mock_results
            result = await self.wrapper.aresults("test query")

        mock_async_search.assert_called_once_with("test query", num=5)
        assert result == mock_results

    def test_build_headers(self):
        """``build_headers`` declares JSON for both request and response bodies."""
        headers = self.wrapper.build_headers()
        assert headers["Content-Type"] == "application/json"
        assert headers["Accept"] == "application/json"

    def test_build_headers_empty_datasets(self):
        """``build_request_params`` with no datasets carries the query and default limits."""
        wrapper = LocalDatasetAPIWrapper(
            search_api_key=bytearray(b"key"),
            search_url=SecretStr("https://example.com"),
            search_datasets=[]
        )
        body_params, query_params = wrapper.build_request_params("test query")
        assert body_params["query"] == "test query"
        assert query_params["top_k"] == 5
        assert query_params["recall_threshold"] == 0.5

    def test_build_headers_custom_threshold(self):
        """A custom ``recall_threshold`` is forwarded in the query parameters."""
        wrapper = LocalDatasetAPIWrapper(
            search_api_key=bytearray(b"key"),
            search_url=SecretStr("https://example.com"),
            recall_threshold=0.9
        )
        _, query_params = wrapper.build_request_params("test query")
        assert query_params["recall_threshold"] == 0.9

    def test_search_api_results_success(self):
        """The sync search returns the first ``num`` entries of ``output_list``."""
        mock_response = Mock()
        mock_response.json.return_value = {"output_list": self._sample_output_list()}
        mock_response.raise_for_status.return_value = None

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", return_value=mock_response), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")):
            result = self.wrapper._search_api_results("test query", num=3)

        assert len(result) == 3
        assert [hit["content"] for hit in result] == ["Result 1", "Result 2", "Result 3"]

    def test_search_api_results_ssl_verify_false(self):
        """With SSL verification disabled, ``requests.post`` receives ``verify=False``."""
        mock_response = Mock()
        # A well-formed but empty payload keeps this test on the success path,
        # so it exercises only the ``verify`` flag.
        mock_response.json.return_value = {"output_list": []}
        mock_response.raise_for_status.return_value = None

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", return_value=mock_response) as mock_post, \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(False, None)):
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []
        mock_post.assert_called_once()
        assert mock_post.call_args.kwargs['verify'] is False

    def test_search_api_results_request_exception(self):
        """Non-sensitive mode: a request failure returns [] and logs the error text."""
        failure = requests.exceptions.RequestException("Connection failed")

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", side_effect=failure), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with(
            "Search request failed! Error: Connection failed"
        )

    def test_search_api_results_request_exception_sensitive_mode(self):
        """Sensitive mode: a request failure returns [] and logs without error details."""
        failure = requests.exceptions.RequestException("Connection failed")

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", side_effect=failure), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with("Search request failed!")

    def test_search_api_results_unexpected_response_format(self):
        """A payload without ``output_list`` yields an empty result list."""
        mock_response = Mock()
        mock_response.json.return_value = {'invalid_key': 'unexpected_data'}
        mock_response.raise_for_status.return_value = None

        # The logger is patched only to keep test output quiet; nothing is asserted on it.
        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", return_value=mock_response), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger"):
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []

    def test_search_api_results_non_list_retrieve_result_list(self):
        """Sensitive mode: a non-list ``output_list`` returns [] and logs a redacted error."""
        mock_response = Mock()
        mock_response.json.return_value = {'output_list': "not_a_list"}
        mock_response.raise_for_status.return_value = None

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", return_value=mock_response), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with("Unexpected search request response!")

    def test_search_api_results_log_is_not_sensitive(self):
        """Non-sensitive mode: a non-list ``output_list`` returns [] and logs the raw value."""
        mock_response = Mock()
        mock_response.json.return_value = {'output_list': "not_a_list"}
        mock_response.raise_for_status.return_value = None

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.requests.post", return_value=mock_response), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = self.wrapper._search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with("Unexpected response! original result: not_a_list")

    @pytest.mark.asyncio
    async def test_async_search_api_results_success(self):
        """The async search returns every hit when ``num`` covers the whole payload."""
        mock_response = AsyncMock()
        mock_response.status = 200
        mock_response.json.return_value = {"output_list": self._sample_output_list()}

        # Async context manager handed back by ``session.post(...)``.
        mock_post_context = MagicMock()
        mock_post_context.__aenter__.return_value = mock_response
        mock_post_context.__aexit__.return_value = None

        mock_session = MagicMock()
        mock_session.post.return_value = mock_post_context

        # Async context manager handed back by ``aiohttp.ClientSession(...)``.
        mock_session_context = MagicMock()
        mock_session_context.__aenter__.return_value = mock_session
        mock_session_context.__aexit__.return_value = None

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.aiohttp.ClientSession", return_value=mock_session_context), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.logger"):
            result = await self.wrapper._async_search_api_results("test query", num=5)

        assert len(result) == 5

    @pytest.mark.asyncio
    async def test_async_search_api_results_with_sensitive_client_error(self):
        """Sensitive mode: an aiohttp client error returns [] and logs without details."""
        failure = aiohttp.ClientError("Connection failed")

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.aiohttp.ClientSession", side_effect=failure), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = await self.wrapper._async_search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with("Search request failed!")

    @pytest.mark.asyncio
    async def test_async_search_api_results_ssl_false_client_error(self):
        """SSL verification off, non-sensitive mode: a client error returns [] and logs its text."""
        failure = aiohttp.ClientError("Connection failed")

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.aiohttp.ClientSession", side_effect=failure), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(False, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = await self.wrapper._async_search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with(
            "Search request failed! Error: Connection failed"
        )

    @pytest.mark.asyncio
    async def test_async_search_api_results_client_error(self):
        """SSL verification on, non-sensitive mode: a client error returns [] and logs its text."""
        failure = aiohttp.ClientError("Connection failed")

        with patch.object(self.wrapper, 'build_headers', return_value=self._fake_headers()), \
                patch(f"{MODULE_PATH}.aiohttp.ClientSession", side_effect=failure), \
                patch(f"{MODULE_PATH}.SslUtils.get_ssl_config", return_value=(True, "/path/to/cert")), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            result = await self.wrapper._async_search_api_results("test query", num=5)

        assert result == []
        mock_logger.error.assert_called_once_with(
            "Search request failed! Error: Connection failed"
        )