import pytest
from unittest.mock import patch
from openjiuwen_deepsearch.algorithm.research_collector.tool_log import \
is_sensitive_key, get_logged_tool, tool_invoke_log, tool_invoke_log_async
from openjiuwen_deepsearch.common.exception import CustomValueException
# Dotted import path of the module under test; the tests below prefix it onto
# their patch() targets (e.g. f"{MODULE_PATH}.logger").
MODULE_PATH = "openjiuwen_deepsearch.algorithm.research_collector.tool_log"
class TestIsSensitiveKey:
    """Tests for the ``is_sensitive_key`` function."""

    def test_is_sensitive_key_positive_cases(self):
        """Every key name in the list must be reported as sensitive."""
        sensitive_keys = [
            "api_key", "secret_key", "access_key", "key",
            "database_url", "redis_url", "endpoint_url",
            "auth_token", "bearer_token", "token", "url",
        ]
        for key in sensitive_keys:
            # Plain truthiness instead of ``== True`` (PEP 8 recommendation).
            assert is_sensitive_key(key), f"Key '{key}' should be sensitive"

    def test_is_sensitive_key_negative_cases(self):
        """Every key name in the list must be reported as not sensitive."""
        non_sensitive_keys = [
            "name", "title", "content", "description",
            "count", "limit", "offset", "page",
            "query", "search", "filter", "sort",
        ]
        for key in non_sensitive_keys:
            assert not is_sensitive_key(key), f"Key '{key}' should not be sensitive"

    def test_is_sensitive_key_case_insensitive(self):
        """Upper-case and mixed-case spellings must still be reported as sensitive."""
        assert is_sensitive_key("API_KEY")
        assert is_sensitive_key("Api_Key")
        assert is_sensitive_key("api_Key")
        assert is_sensitive_key("TOKEN")
        assert is_sensitive_key("Url")

    def test_is_sensitive_key_partial_matches(self):
        """Key names that merely contain a sensitive word must be reported as sensitive."""
        assert is_sensitive_key("my_api_key_here")
        assert is_sensitive_key("access_token_value")
        assert is_sensitive_key("database_url_string")
        assert is_sensitive_key("some_key_name")
class TestGetLoggedTool:
    """Tests for the ``get_logged_tool`` class factory."""

    def setup_method(self):
        """Build a minimal tool class and expose it to the tests as ``self.BaseTool``."""

        # Stand-in tool: both entry points simply echo the name and arguments.
        class BaseTool:
            def __init__(self, name="TestTool"):
                self.name = name

            def _run(self, *args, **kwargs):
                return f"Result from {self.name} with args: {args}, kwargs: {kwargs}"

            async def _arun(self, *args, **kwargs):
                return f"Async result from {self.name} with args: {args}, kwargs: {kwargs}"

        self.BaseTool = BaseTool

    def test_get_logged_tool_creates_correct_class(self):
        """The generated class is named after, derives from, and extends the base class."""
        wrapped_cls = get_logged_tool(self.BaseTool)

        assert wrapped_cls.__name__ == "BaseToolWithLogging"
        assert issubclass(wrapped_cls, self.BaseTool)
        expected_attrs = (
            "_log_start", "_log_end", "_log_error",
            "_get_tool_name", "_run", "_arun",
        )
        for attr_name in expected_attrs:
            assert hasattr(wrapped_cls, attr_name)

    def test_get_logged_tool_format_params_non_sensitive(self):
        """Non-sensitive positional and keyword arguments appear in the formatted text."""
        wrapped_cls = get_logged_tool(self.BaseTool)

        formatted = wrapped_cls._format_params(
            ("value1", "list"),
            {"name": "test", "count": "5"},
        )

        for fragment in ("value1", "list", "name='test'", "count='5'"):
            assert fragment in formatted

    def test_get_logged_tool_format_params_sensitive(self):
        """Sensitive arguments are absent from the formatted text; others are kept."""
        wrapped_cls = get_logged_tool(self.BaseTool)

        formatted = wrapped_cls._format_params(
            ("api_key_value", "normal_arg"),
            {"api_key": "secret", "token": "bearer_token", "name": "test"},
        )

        for hidden in ("api", "key", "token"):
            assert hidden not in formatted
        for shown in ("normal_arg", "name='test'"):
            assert shown in formatted

    def test_get_logged_tool_truncate_result(self):
        """A short result is returned as its repr; a 150-char one is cut and ends in '...'."""
        wrapped_cls = get_logged_tool(self.BaseTool)

        assert wrapped_cls._truncate_result("short result") == "'short result'"

        shortened = wrapped_cls._truncate_result("a" * 150)
        assert len(shortened) == 103
        assert shortened.endswith("...")

    def test_get_logged_tool_get_tool_name(self):
        """``_get_tool_name`` yields the base class name."""
        tool = get_logged_tool(self.BaseTool)()
        assert tool._get_tool_name() == "BaseTool"

    def test_get_logged_tool_run_success(self):
        """Synchronous ``_run`` still returns the wrapped tool's result."""
        tool = get_logged_tool(self.BaseTool)()

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger"):
            output = tool._run("arg1", key1="value1")

            assert "Result from TestTool" in output

    def test_get_logged_tool_run_sensitive_mode(self):
        """In sensitive mode the expected start/end info messages are logged."""
        tool = get_logged_tool(self.BaseTool)()

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            output = tool._run("arg1", key1="value1")

            mock_logger.info.assert_any_call("[TOOL START] BaseTool._run")
            mock_logger.info.assert_any_call("[TOOL END] BaseTool._run | Duration: 0.50s")
            assert "Result from TestTool" in output

    def test_get_logged_tool_run_exception(self):
        """A failing ``_run`` raises CustomValueException and logs the error text once."""

        class FailingTool:
            def _run(self, *args, **kwargs):
                raise ValueError("Test error")

        tool = get_logged_tool(FailingTool)()

        with patch(f"{MODULE_PATH}.time.time", return_value=100.0), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            with pytest.raises(CustomValueException):
                tool._run("arg1")

            mock_logger.error.assert_called_once()
            logged_message = mock_logger.error.call_args[0][0]
            assert "[TOOL ERROR] FailingTool._run | Error: Test error" in logged_message

    def test_get_logged_tool_run_exception_sensitive_mode(self):
        """In sensitive mode a failing ``_run`` logs only the bare error marker."""

        class FailingTool:
            def _run(self, *args, **kwargs):
                raise ValueError("Test error")

        tool = get_logged_tool(FailingTool)()

        with patch(f"{MODULE_PATH}.time.time", return_value=100.0), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            with pytest.raises(CustomValueException):
                tool._run("arg1")

            mock_logger.error.assert_called_once_with("[TOOL ERROR] FailingTool._run")

    @pytest.mark.asyncio
    async def test_get_logged_tool_arun_success(self):
        """Asynchronous ``_arun`` still returns the wrapped tool's result."""
        tool = get_logged_tool(self.BaseTool)()

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger"):
            output = await tool._arun("arg1", key1="value1")

            assert "Async result from TestTool" in output

    @pytest.mark.asyncio
    async def test_get_logged_tool_arun_exception(self):
        """A failing ``_arun`` raises CustomValueException and logs the error text once."""

        class AsyncFailingTool:
            async def _arun(self, *args, **kwargs):
                raise ValueError("Async test error")

        tool = get_logged_tool(AsyncFailingTool)()

        with patch(f"{MODULE_PATH}.time.time", return_value=100.0), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            with pytest.raises(CustomValueException):
                await tool._arun("arg1")

            mock_logger.error.assert_called_once()
            logged_message = mock_logger.error.call_args[0][0]
            assert "[TOOL ERROR] AsyncFailingTool._arun | Error: Async test error" in logged_message
class TestToolInvokeLog:
    """Tests for the ``tool_invoke_log`` decorator."""

    @staticmethod
    def _info_messages_containing(mock_logger, marker):
        """Return the first positional argument of each ``logger.info`` call containing *marker*."""
        return [
            logged[0][0]
            for logged in mock_logger.info.call_args_list
            if marker in logged[0][0]
        ]

    def test_tool_invoke_log_success(self):
        """A decorated function returns its result and logs one start and one end message."""

        @tool_invoke_log
        def my_test_function(arg1, arg2, key1=None):
            return f"Result: {arg1}, {arg2}, {key1}"

        with patch(f"{MODULE_PATH}.time.time") as mock_time, \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive") as mock_sensitive, \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            mock_time.side_effect = [100.0, 100.5]
            mock_sensitive.return_value = False

            result = my_test_function("value1", "value2", key1="test")

            start_msgs = self._info_messages_containing(mock_logger, "[TOOL START]")
            end_msgs = self._info_messages_containing(mock_logger, "[TOOL END]")
            assert len(start_msgs) == 1
            assert len(end_msgs) == 1

            start_msg = start_msgs[0]
            end_msg = end_msgs[0]
            assert "my_test_function" in start_msg or "Start to execute tool" in start_msg
            assert "value1" in start_msg
            assert "value2" in start_msg
            assert "my_test_function" in end_msg or "Duration: " in end_msg
            assert "Result: value1, value2, test" in end_msg
            assert result == "Result: value1, value2, test"

    def test_tool_invoke_log_sensitive_args_filtering(self):
        """The start message omits the sensitive arguments but keeps the normal one."""

        @tool_invoke_log
        def sensitive_test_function(api_key, token, normal_arg):
            return "success"

        with patch(f"{MODULE_PATH}.time.time") as mock_time, \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive") as mock_sensitive, \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            mock_time.side_effect = [100.0, 100.5]
            mock_sensitive.return_value = False

            result = sensitive_test_function("secret_key", "bearer_token", "normal_value")

            start_msgs = self._info_messages_containing(mock_logger, "[TOOL START]")
            assert len(start_msgs) == 1

            start_msg = start_msgs[0]
            assert "secret_key" not in start_msg
            assert "bearer_token" not in start_msg
            assert "api_key" not in start_msg
            assert "token" not in start_msg
            assert "normal_value" in start_msg
            assert result == "success"

    def test_tool_invoke_log_function_name_extraction(self):
        """Some info message names the decorated function or carries the start text."""

        @tool_invoke_log
        def different_name_function():
            return "test"

        with patch(f"{MODULE_PATH}.time.time") as mock_time, \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive") as mock_sensitive, \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            mock_time.side_effect = [100.0, 100.5]
            mock_sensitive.return_value = False

            different_name_function()

            all_messages = [logged[0][0] for logged in mock_logger.info.call_args_list]
            function_name_in_logs = any(
                "different_name_function" in msg for msg in all_messages
            )
            assert function_name_in_logs or any(
                "Start to execute tool" in msg for msg in all_messages
            )
class TestToolInvokeLogAsync:
    """Tests for the ``tool_invoke_log_async`` decorator."""

    @pytest.mark.asyncio
    async def test_tool_invoke_log_async_success(self):
        """A decorated coroutine returns its result and logs the full end message."""

        @tool_invoke_log_async
        async def async_test_function(arg1, arg2, key1=None):
            return f"Async Result: {arg1}, {arg2}, {key1}"

        expected_end = "[TOOL END] async_test_function | Args: value1, value2 | Tool result count: 34 | Result content: 'Async Result: value1, value2, test' | Duration: 0.50s"

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            outcome = await async_test_function("value1", "value2", key1="test")

            mock_logger.info.assert_any_call(expected_end)
            assert outcome == "Async Result: value1, value2, test"

    @pytest.mark.asyncio
    async def test_tool_invoke_log_async_sensitive_mode(self):
        """In sensitive mode the start/end messages carry no arguments or result content."""

        @tool_invoke_log_async
        async def async_test_function(arg1, arg2):
            return f"Async Result: {arg1}, {arg2}"

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            outcome = await async_test_function("value1", "value2")

            mock_logger.info.assert_any_call("[TOOL START] async_test_function")
            mock_logger.info.assert_any_call(
                "[TOOL END] async_test_function | Tool result count: 28 | Duration: 0.50s"
            )
            assert outcome == "Async Result: value1, value2"

    @pytest.mark.asyncio
    async def test_tool_invoke_log_async_exception(self):
        """A failing coroutine raises CustomValueException and logs the exception repr once."""

        @tool_invoke_log_async
        async def async_failing_function():
            raise ValueError("Async test error")

        with patch(f"{MODULE_PATH}.time.time", return_value=100.0), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            with pytest.raises(CustomValueException):
                await async_failing_function()

            mock_logger.error.assert_called_once()
            logged_message = mock_logger.error.call_args[0][0]
            assert "Exception: ValueError('Async test error')" in logged_message

    @pytest.mark.asyncio
    async def test_tool_invoke_log_async_exception_sensitive_mode(self):
        """In sensitive mode a failing coroutine logs only the generic error message."""

        @tool_invoke_log_async
        async def async_failing_function():
            raise ValueError("Async test error")

        with patch(f"{MODULE_PATH}.time.time", return_value=100.0), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=True), \
                patch(f"{MODULE_PATH}.logger") as mock_logger:
            with pytest.raises(CustomValueException):
                await async_failing_function()

            mock_logger.error.assert_called_with(
                "[TOOL ERROR] async_failing_function | Raise exception"
            )

    @pytest.mark.asyncio
    async def test_tool_invoke_log_async_sync_function_returning_coroutine(self):
        """A sync function that returns a coroutine object is still awaited by the decorator."""

        async def _inner():
            return {"search_results": [{"content": "ok"}]}

        @tool_invoke_log_async
        def sync_outer():
            return _inner()

        with patch(f"{MODULE_PATH}.time.time", side_effect=[100.0, 100.5]), \
                patch(f"{MODULE_PATH}.LogManager.is_sensitive", return_value=False), \
                patch(f"{MODULE_PATH}.logger"):
            outcome = await sync_outer()

            assert outcome == {"search_results": [{"content": "ok"}]}