from unittest.mock import patch, MagicMock, call
import unittest
from parameterized import parameterized
from sql.storage import DynamicStorage
import json
class TestDynamicStorage(unittest.TestCase):
"""
测试 DynamicStorage 类的行为
"""
# 存储类型到类的映射
storage_classes = {
"local": "FileSystemStorage",
"sftp": "SFTPStorage",
"s3c": "S3Boto3Storage",
"azure": "AzureStorage",
}
def setUp(self):
"""通用配置数据"""
self.local_config = {
"storage_type": "local",
"local_path": "/tmp/files/",
}
self.sftp_config = {
"storage_type": "sftp",
"sftp_host": "sftp.example.com",
"sftp_user": "user",
"sftp_password": "pass",
"sftp_port": 22,
"sftp_path": "/uploads/",
}
self.s3c_config = {
"storage_type": "s3c",
"s3c_access_key_id": "AKIA...",
"s3c_access_key_secret": "secret",
"s3c_endpoint": "http://s3.example.com",
"s3c_bucket_name": "my-bucket",
"s3c_region": "us-east-1",
"s3c_path": "data/",
}
self.azure_config = {
"storage_type": "azure",
"azure_account_name": "myaccount",
"azure_account_key": "azurekey",
"azure_container": "container",
"azure_path": "azure-data/",
}
@parameterized.expand(
[
(
"local",
"FileSystemStorage",
{
"location": "/tmp/files/",
"base_url": "/tmp/files/",
},
),
(
"sftp",
"SFTPStorage",
{
"host": "sftp.example.com",
"params": {
"username": "user",
"password": "pass",
"port": 22,
},
"root_path": "/uploads/",
},
),
(
"s3c",
"S3Boto3Storage",
{
"access_key": "AKIA...",
"secret_key": "secret",
"bucket_name": "my-bucket",
"region_name": "us-east-1",
"endpoint_url": "http://s3.example.com",
"location": "data/",
"file_overwrite": False,
},
),
(
"azure",
"AzureStorage",
{
"account_name": "myaccount",
"account_key": "azurekey",
"azure_container": "container",
"location": "azure-data/",
},
),
]
)
def test_storage_initialization(self, storage_type, storage_class, expected_kwargs):
"""参数化测试存储后端初始化"""
config_dict = getattr(self, f"{storage_type}_config")
with patch(f"sql.storage.{storage_class}") as mock_storage:
DynamicStorage(config_dict=config_dict)
# 验证正确的存储类被调用
mock_storage.assert_called_once()
# 验证调用参数
actual_kwargs = mock_storage.call_args[1]
self.assertDictEqual(actual_kwargs, expected_kwargs)
@parameterized.expand(
[
("save", "test.txt", "content"),
("open", "test.txt", "rb"),
("delete", "test.txt", None),
("exists", "test.txt", None),
("size", "test.txt", None),
]
)
def test_method_proxying(self, method, filename, content):
"""测试方法代理到底层存储"""
mock_storage = MagicMock()
with patch.object(DynamicStorage, "_init_storage", return_value=mock_storage):
storage = DynamicStorage(config_dict=self.local_config)
# 调用代理方法
method_call = getattr(storage, method)
if content:
method_call(filename, content)
else:
method_call(filename)
# 验证底层存储方法被调用
underlying_call = getattr(mock_storage, method)
if content:
underlying_call.assert_called_once_with(filename, content)
else:
underlying_call.assert_called_once_with(filename)
def test_close_behavior(self):
"""测试 close 方法的两种场景"""
# 场景1:底层存储有 close 方法
mock_storage_with_close = MagicMock()
mock_storage_with_close.close = MagicMock()
with patch.object(
DynamicStorage, "_init_storage", return_value=mock_storage_with_close
):
storage = DynamicStorage(config_dict=self.sftp_config)
storage.close()
mock_storage_with_close.close.assert_called_once()
# 场景2:底层存储无 close 方法
mock_storage_without_close = MagicMock()
delattr(mock_storage_without_close, "close")
with patch.object(
DynamicStorage, "_init_storage", return_value=mock_storage_without_close
):
storage = DynamicStorage(config_dict=self.local_config)
try:
storage.close() # 不应报错
except Exception as e:
self.fail(f"close() raised {e}")
def test_storage_operation_exceptions(self):
"""测试存储操作异常处理"""
for method in ["open", "save", "delete", "exists", "size"]:
with self.subTest(method=method):
with patch.object(DynamicStorage, "_init_storage") as mock_init:
mock_storage = MagicMock()
mock_init.return_value = mock_storage
# 模拟底层存储抛出异常
getattr(mock_storage, method).side_effect = Exception("存储错误")
storage = DynamicStorage(config_dict=self.local_config)
with self.assertRaises(Exception) as context:
if method == "save":
getattr(storage, method)("test.txt", "content")
else:
getattr(storage, method)("test.txt")
self.assertIn("存储错误", str(context.exception))
def test_init_unsupported_storage_type(self):
"""测试不支持的存储类型抛出 ValueError"""
config = {"storage_type": "unsupported"}
with self.assertRaises(ValueError) as context:
DynamicStorage(config_dict=config)
self.assertIn("不支持的存储类型", str(context.exception))
def test_check_connection_local(self):
"""测试本地存储连接检查"""
storage = DynamicStorage(config_dict=self.local_config)
success, msg = storage.check_connection()
self.assertTrue(success)
self.assertEqual(msg, "本地存储连接成功")
def test_check_connection_sftp(self):
"""测试 SFTP 连接检查"""
with patch("sql.storage.SFTPStorage") as mock_sftp_class:
# 完整的上下文管理器模拟
mock_context = MagicMock()
mock_context.listdir.return_value = ([], [])
mock_sftp_instance = MagicMock()
mock_sftp_instance.__enter__.return_value = mock_context
mock_sftp_instance.__exit__.return_value = None
mock_sftp_class.return_value = mock_sftp_instance
# 成功场景
storage = DynamicStorage(config_dict=self.sftp_config)
success, msg = storage.check_connection()
mock_context.listdir.assert_called_once_with(".")
self.assertTrue(success)
self.assertEqual(msg, "SFTP 连接成功")
# 失败场景
mock_context.listdir.side_effect = Exception("SFTP连接失败")
success, msg = storage.check_connection()
self.assertFalse(success)
self.assertIn("SFTP连接失败", msg)
def test_check_connection_s3c(self):
"""测试 S3 兼容存储连接检查"""
with patch("sql.storage.S3Boto3Storage") as mock_s3_class:
mock_s3_instance = MagicMock()
mock_s3_class.return_value = mock_s3_instance
# 设置 bucket_name 属性
mock_s3_instance.bucket_name = "my-bucket"
# 成功场景
mock_client = MagicMock()
mock_s3_instance.connection.meta.client = mock_client
storage = DynamicStorage(config_dict=self.s3c_config)
success, msg = storage.check_connection()
mock_client.head_bucket.assert_called_once_with(Bucket="my-bucket")
self.assertTrue(success)
self.assertEqual(msg, "S3 存储连接成功")
# 失败场景
mock_client.head_bucket.side_effect = Exception("Bucket 不存在")
success, msg = storage.check_connection()
self.assertFalse(success)
self.assertIn("Bucket 不存在", msg)
def test_check_connection_azure(self):
"""测试 Azure Blob 存储连接检查"""
with patch("sql.storage.AzureStorage") as mock_azure_class:
mock_azure_instance = MagicMock()
mock_azure_class.return_value = mock_azure_instance
# 成功场景
mock_client = MagicMock()
mock_azure_instance.client = mock_client
storage = DynamicStorage(config_dict=self.azure_config)
success, msg = storage.check_connection()
mock_client.get_container_properties.assert_called_once()
self.assertTrue(success)
self.assertEqual(msg, "Azure Blob 存储连接成功")
# 失败场景
mock_client.get_container_properties.side_effect = Exception("容器不存在")
success, msg = storage.check_connection()
self.assertFalse(success)
self.assertIn("容器不存在", msg)
@parameterized.expand(
[
("sftp", "sftp_custom_params", '{"timeout": 30}'),
("s3c", "s3c_custom_params", '{"addressing_style": "virtual"}'),
("azure", "azure_custom_params", '{"max_connections": 10}'),
]
)
def test_custom_json_params(self, storage_type, param_name, json_value):
"""自定义JSON参数处理测试"""
config_dict = getattr(self, f"{storage_type}_config").copy()
config_dict[param_name] = json_value
with patch(f"sql.storage.{self.storage_classes[storage_type]}") as mock_storage:
DynamicStorage(config_dict=config_dict)
call_args = mock_storage.call_args[1]
expected_value = json.loads(json_value)
for key, value in expected_value.items():
self.assertEqual(call_args[key], value)
@parameterized.expand(
[
("sftp", "sftp_custom_params"),
("s3c", "s3c_custom_params"),
("azure", "azure_custom_params"),
]
)
def test_invalid_json_params(self, storage_type, param_name):
"""测试无效JSON参数处理"""
config_dict = getattr(self, f"{storage_type}_config").copy()
config_dict[param_name] = "invalid{json"
with patch(f"sql.storage.{self.storage_classes[storage_type]}") as mock_storage:
try:
storage = DynamicStorage(config_dict=config_dict)
# 验证基础参数仍然正确设置
if storage_type == "sftp":
self.assertEqual(storage.sftp_path, self.sftp_config["sftp_path"])
elif storage_type == "s3c":
self.assertEqual(
storage.s3c_bucket_name, self.s3c_config["s3c_bucket_name"]
)
else:
self.assertEqual(
storage.azure_container, self.azure_config["azure_container"]
)
except json.JSONDecodeError:
self.fail("Invalid JSON should be handled gracefully")
@parameterized.expand(
[
("sftp", "sftp_password", ""),
("s3c", "s3c_access_key_secret", ""),
("azure", "azure_account_key", ""),
]
)
def test_empty_config_values(self, storage_type, param_name, empty_value):
"""测试空配置参数处理"""
config_dict = getattr(self, f"{storage_type}_config").copy()
config_dict[param_name] = empty_value
with patch(f"sql.storage.{self.storage_classes[storage_type]}") as mock_storage:
DynamicStorage(config_dict=config_dict)
call_args = mock_storage.call_args[1]
if storage_type == "sftp":
self.assertEqual(call_args["params"]["password"], "")
elif storage_type == "s3c":
self.assertEqual(call_args["secret_key"], "")
else:
self.assertEqual(call_args["account_key"], "")