import logging
import os
import sys
from pathlib import Path
import pytest
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.utils.log_utils.log_handlers import SafeRotatingFileHandler
from openjiuwen_deepsearch.utils.log_utils.log_common import DEFAULT_MAX_LOG_MESSAGE_LENGTH
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
@pytest.fixture
def clean_logs(tmp_path, monkeypatch):
    """Provide an isolated log root and reset global logging state around a test.

    Setup points ``LogManager._SAFE_BASE`` at ``<tmp_path>/logs`` (restored automatically
    by ``monkeypatch`` on teardown), marks ``LogManager`` as uninitialised, and snapshots
    ``disabled`` / ``propagate`` / ``level`` of every logger listed in
    ``LogManager._THIRD_PARTY_LOGGERS`` plus the root logger level.

    Yields:
        Path: the ``logs`` directory tests should use as their log root.

    Teardown flushes, closes and detaches every root-logger handler, restores the
    snapshotted logger attributes, and resets ``LogManager._initialized`` to ``False``.
    """
    safe_base = tmp_path / "logs"
    safe_base.mkdir(parents=True)
    # Patch via monkeypatch so the temporary directory does not stay configured as the
    # safe base for later tests / other modules once this tmp_path is deleted.
    monkeypatch.setattr(LogManager, "_SAFE_BASE", str(safe_base), raising=False)
    LogManager._initialized = False

    root_logger = logging.getLogger()
    # LogManager.init may change the root level (tests call it with level="DEBUG");
    # remember it so the change does not bleed into other tests.
    root_level = root_logger.level
    third_party_states = {}
    for logger_name in LogManager._THIRD_PARTY_LOGGERS:
        logger_obj = logging.getLogger(logger_name)
        third_party_states[logger_name] = (
            logger_obj.disabled,
            logger_obj.propagate,
            logger_obj.level,
        )

    yield safe_base

    # Iterate over a copy because removeHandler mutates root_logger.handlers.
    for handler in list(root_logger.handlers):
        handler.flush()
        handler.close()
        root_logger.removeHandler(handler)
    root_logger.setLevel(root_level)

    for logger_name, (disabled, propagate, level) in third_party_states.items():
        logger_obj = logging.getLogger(logger_name)
        logger_obj.disabled = disabled
        logger_obj.propagate = propagate
        logger_obj.setLevel(level)
    LogManager._initialized = False
def _flush_root_handlers():
    """Call ``flush()`` on each handler currently attached to the root logger.

    Used by the tests so buffered records reach the log files before they are read.
    """
    for attached in logging.getLogger().handlers:
        attached.flush()
def _read_common_log(log_root: Path) -> str:
    """Return the text of ``<log_root>/common/common.log``.

    Args:
        log_root: directory that LogManager was initialised with.

    Returns:
        The UTF-8 decoded file content, or ``""`` when the file does not exist yet.
    """
    log_path = log_root / "common" / "common.log"
    return log_path.read_text(encoding="utf-8") if log_path.exists() else ""
def test_safe_log_dir_valid(clean_logs):
    """An existing directory inside the safe base is accepted and returned as its resolved path."""
    sub_dir = clean_logs / "sub"
    sub_dir.mkdir()
    returned = LogManager._safe_log_dir(str(sub_dir))
    expected = str(sub_dir.resolve())
    assert returned == expected
def test_safe_log_dir_invalid_not_subdir(clean_logs):
    """A directory outside the safe base is rejected and the error mentions the LOG_DIR_UNSAFE code."""
    outside_dir = clean_logs.parent / "not_inside"
    outside_dir.mkdir()
    unsafe_code = str(StatusCode.PARAM_CHECK_ERROR_LOG_DIR_UNSAFE.code)
    with pytest.raises(CustomValueException) as exc_info:
        LogManager._safe_log_dir(str(outside_dir))
    assert unsafe_code in str(exc_info.value)
def test_safe_log_dir_invalid_path(clean_logs):
    """``_safe_log_dir`` raises CustomValueException for the path ``/???/illegal_path``.

    NOTE(review): ``?`` is a legal filename character on POSIX, so there this path is
    presumably rejected for lying outside the safe base rather than for being
    syntactically illegal — confirm which check this test is meant to cover.
    """
    bogus_path = "/???/illegal_path"
    with pytest.raises(CustomValueException):
        LogManager._safe_log_dir(bogus_path)
def test_logmanager_init_once(clean_logs, monkeypatch):
    """A second ``LogManager.init`` call is a no-op.

    The three logger-setup functions are replaced with call counters. After two ``init``
    calls each must have run exactly once, and the ``is_sensitive`` value from the first
    call (False) must not be overwritten by the second call (True).
    """
    LogManager._initialized = False
    call_counts = {"common": 0, "metrics": 0, "interface": 0}

    def _make_counter(key):
        def _counter(*args, **kwargs):
            call_counts[key] += 1
        return _counter

    patch_targets = {
        "common": "openjiuwen_deepsearch.utils.log_utils.log_manager.setup_common_logger",
        "metrics": "openjiuwen_deepsearch.utils.log_utils.log_manager.setup_metrics_logger",
        "interface": "openjiuwen_deepsearch.utils.log_utils.log_manager.setup_interface_logger",
    }
    for key, target in patch_targets.items():
        monkeypatch.setattr(target, _make_counter(key))

    log_dir = str(clean_logs / "sub")
    for sensitive in (False, True):
        LogManager.init(log_dir=log_dir, is_sensitive=sensitive)

    assert call_counts == {"common": 1, "metrics": 1, "interface": 1}
    assert LogManager.is_sensitive() is False
def test_is_sensitive_set(clean_logs):
    """``init(is_sensitive=True)`` is reported back by ``LogManager.is_sensitive()``."""
    LogManager._initialized = False
    log_root = str(clean_logs)
    LogManager.init(log_dir=log_root, is_sensitive=True)
    assert LogManager.is_sensitive() is True
@pytest.mark.parametrize(
    ("kwargs", "expected_code"),
    [
        pytest.param({"is_sensitive": "not_bool"}, 200020, id="is_sensitive-not-bool"),
        pytest.param({"level": 123}, 200005, id="level-not-str"),
        pytest.param({"level": "OTHER_LEVEL"}, 200014, id="level-unknown-name"),
        pytest.param({"max_bytes": "100MB"}, 200005, id="max_bytes-not-int"),
        pytest.param({"max_bytes": -1}, 200025, id="max_bytes-negative"),
        pytest.param({"max_bytes": 2000 * 1024 * 1024}, 200025, id="max_bytes-too-large"),
        pytest.param({"backup_count": 10.5}, 200005, id="backup_count-not-int"),
        pytest.param({"backup_count": -1}, 200025, id="backup_count-negative"),
        pytest.param({"backup_count": 1001}, 200025, id="backup_count-too-large"),
    ],
)
def test_init_validation_errors(clean_logs, kwargs, expected_code):
    """``LogManager.init`` rejects each invalid argument with the expected error code.

    Each case runs as its own parametrized test (with a fresh ``clean_logs`` fixture), so
    every invalid input is checked and reported individually instead of the run stopping
    at the first failing case.

    Args:
        kwargs: the single invalid keyword argument(s) merged over a valid ``log_dir``.
        expected_code: the ``error_code`` the raised CustomValueException must carry.
    """
    LogManager._initialized = False
    params = {"log_dir": str(clean_logs / "sub"), **kwargs}
    with pytest.raises(CustomValueException) as exc:
        LogManager.init(**params)
    assert exc.value.error_code == expected_code, f"unexpected error code for {kwargs!r}"
def test_safe_log_dir_sets_permission(clean_logs):
    """``_safe_log_dir`` creates a missing directory and applies restrictive permissions.

    On POSIX the created directory must have mode ``0o750``. Windows has no POSIX mode
    bits, so there the test only checks that the directory exists and is writable.
    """
    target = clean_logs / "new_sub_dir"
    result_path = Path(LogManager._safe_log_dir(str(target)))
    assert result_path.exists()
    if sys.platform == "win32":
        # Mode bits are meaningless on Windows; just confirm the directory is writable.
        assert os.access(result_path, os.W_OK)
        return
    mode = result_path.stat().st_mode & 0o777
    assert mode == 0o750, f"Expected mode 0o750, got {oct(mode)}"
def test_safe_rotating_file_handler_permissions(clean_logs):
    """SafeRotatingFileHandler creates its directory/file and applies restrictive modes.

    Expected POSIX modes: log directory ``0o750``, active log file ``0o640``, rotated
    backups ``0o440``. On Windows only creation and writability are checked.
    """
    log_file = clean_logs / "test_dir" / "test.log"
    handler = SafeRotatingFileHandler(
        filename=str(log_file),
        maxBytes=1024,
        backupCount=2,
        delay=True,
    )
    logger = logging.getLogger("test_perm")
    previous_level = logger.level
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    try:
        logger.info("First message")
        log_dir = log_file.parent
        for i in range(50):
            logger.info(f"Message {i} to fill log")

        if sys.platform == "win32":
            assert log_dir.exists()
            assert log_file.exists()
            assert os.access(log_dir, os.W_OK)
            print("Windows: 跳过POSIX权限检查,验证文件和目录创建、轮转逻辑。")
            return

        dir_mode = log_dir.stat().st_mode & 0o777
        assert dir_mode == 0o750, f"目录权限不符: 期望 0o750, 实际 {oct(dir_mode)}"
        file_mode = log_file.stat().st_mode & 0o777
        assert file_mode == 0o640, f"活跃日志文件权限不符: 期望 0o640, 实际 {oct(file_mode)}"

        handler.doRollover()
        found_backup = False
        for i in range(1, handler.backupCount + 1):
            backup = Path(f"{log_file}.{i}")
            if not backup.exists():
                continue
            found_backup = True
            backup_mode = backup.stat().st_mode & 0o777
            assert backup_mode == 0o440, f"备份文件 {i} 权限不符: 期望 0o440, 实际 {oct(backup_mode)}"
        # Without this the mode loop above would pass vacuously if no backup was produced.
        assert found_backup, "doRollover() produced no backup files to check"
    finally:
        # Always detach and close: otherwise the module-level "test_perm" logger keeps a
        # handler bound to a deleted tmp directory, and the file handle stays open when an
        # assertion fails.
        logger.removeHandler(handler)
        handler.close()
        logger.setLevel(previous_level)
def test_common_log_truncates_long_message(clean_logs):
    """An over-length message is truncated in common.log while head and tail fragments survive."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    logger = logging.getLogger("openjiuwen_deepsearch.test_log")
    # Size the middle section from the limit (like the boundary/exception tests do) so the
    # message is guaranteed to exceed DEFAULT_MAX_LOG_MESSAGE_LENGTH even if it is raised.
    body_repeats = max(800, DEFAULT_MAX_LOG_MESSAGE_LENGTH // len("BODY") + 1)
    long_message = "HEAD" * 500 + "BODY" * body_repeats + "TAIL" * 500
    logger.info(long_message)
    _flush_root_handlers()
    common_log_text = _read_common_log(clean_logs)
    assert "truncated, original_len=" in common_log_text
    assert "HEADHEADHEAD" in common_log_text
    assert "TAILTAILTAIL" in common_log_text
    assert long_message not in common_log_text
def test_common_log_keeps_boundary_message_without_truncation(clean_logs):
    """A message exactly DEFAULT_MAX_LOG_MESSAGE_LENGTH characters long is written untruncated."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    boundary_logger = logging.getLogger("openjiuwen_deepsearch.boundary")
    at_limit = "a" * DEFAULT_MAX_LOG_MESSAGE_LENGTH
    boundary_logger.info(at_limit)
    _flush_root_handlers()

    log_text = _read_common_log(clean_logs)
    assert at_limit in log_text
    assert "truncated, original_len=" not in log_text
def test_skip_truncation_preserves_full_message(clean_logs):
    """``extra={"skip_truncation": True}`` writes an over-length message to common.log in full."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    logger = logging.getLogger("openjiuwen_deepsearch.key_log")
    # Derive the length from the limit: if the message did not exceed it, this test
    # would pass without exercising the skip_truncation bypass at all.
    repeats = max(700, DEFAULT_MAX_LOG_MESSAGE_LENGTH // len("0123456789") + 1)
    long_message = "IMPORTANT-" + ("0123456789" * repeats)
    logger.info(long_message, extra={"skip_truncation": True})
    _flush_root_handlers()
    common_log_text = _read_common_log(clean_logs)
    assert long_message in common_log_text
    assert "truncated, original_len=" not in common_log_text
def test_exception_log_truncates_message_and_keeps_traceback(clean_logs):
    """``logger.exception`` with an over-length message is truncated but keeps the traceback."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    exc_logger = logging.getLogger("openjiuwen_deepsearch.exception_log")
    try:
        raise ValueError("boom")
    except ValueError:
        exc_logger.exception("X" * (DEFAULT_MAX_LOG_MESSAGE_LENGTH + 200))
    _flush_root_handlers()

    log_text = _read_common_log(clean_logs)
    for fragment in (
        "truncated, original_len=",
        "Traceback (most recent call last)",
        "ValueError: boom",
    ):
        assert fragment in log_text
def test_third_party_debug_info_are_filtered_but_warning_error_are_kept(clean_logs):
    """Third-party loggers are set to WARNING but still propagate to the common log.

    After ``init`` every logger named in ``LogManager._THIRD_PARTY_LOGGERS`` must be
    enabled, propagating and at level WARNING. An INFO record from
    ``openai._base_client`` must not reach common.log, while its WARNING and ERROR
    records must.
    """
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    for name in LogManager._THIRD_PARTY_LOGGERS:
        configured = logging.getLogger(name)
        assert configured.disabled is False
        assert configured.propagate is True
        assert configured.level == logging.WARNING

    openai_logger = logging.getLogger("openai._base_client")
    openai_logger.info("third-party-info-should-not-appear")
    openai_logger.warning("third-party-warning-should-appear")
    openai_logger.error("third-party-error-should-appear")
    _flush_root_handlers()

    log_text = _read_common_log(clean_logs)
    assert "third-party-info-should-not-appear" not in log_text
    for kept in ("third-party-warning-should-appear", "third-party-error-should-appear"):
        assert kept in log_text
def test_project_logger_is_allowed_to_write_common_log(clean_logs):
    """A WARNING from a project logger (``server.*``) reaches common.log."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    logging.getLogger("server.test_module").warning("project-warning-should-appear")
    _flush_root_handlers()
    assert "project-warning-should-appear" in _read_common_log(clean_logs)
def test_representative_key_log_can_bypass_truncation(clean_logs):
    """A realistic multi-line key log with ``skip_truncation`` is written to common.log in full."""
    LogManager.init(log_dir=str(clean_logs), level="DEBUG", is_sensitive=False)
    logger = logging.getLogger(
        "openjiuwen_deepsearch.algorithm.source_trace.citation_checker_research"
    )
    # Ensure the payload really exceeds the limit; otherwise the bypass is never exercised
    # and the test would pass vacuously.
    repeats = max(900, DEFAULT_MAX_LOG_MESSAGE_LENGTH // len("RESULT-") + 1)
    full_result_text = "=============== result text =================:\n" + ("RESULT-" * repeats)
    logger.info(full_result_text, extra={"skip_truncation": True})
    _flush_root_handlers()
    common_log_text = _read_common_log(clean_logs)
    assert full_result_text in common_log_text
    assert "truncated, original_len=" not in common_log_text