import json
from unittest.mock import mock_open, patch, MagicMock
import pytest
from pydantic import BaseModel
from motor.common.etcd import locks
from motor.common.etcd.etcd_client import EtcdClient
from motor.common.etcd.proto import rpc_pb2, rpc_pb2_grpc
from motor.config.etcd import EtcdConfig
from motor.config.tls_config import TLSConfig
@pytest.fixture
def base_client_with_ssl():
"""Create base client with SSL."""
mock_channel = MagicMock()
mock_kv_stub = MagicMock()
mock_lease_stub = MagicMock()
tls_config = TLSConfig(
enable_tls=True,
ca_file="ca_cert_path",
key_file="cert_key_path",
cert_file="cert_cert_path"
)
etcd_config = EtcdConfig(
etcd_host="test_host",
etcd_port=1234,
etcd_timeout=10
)
with patch("grpc.secure_channel", return_value=mock_channel), \
patch.object(rpc_pb2_grpc, "KVStub", return_value=mock_kv_stub), \
patch.object(rpc_pb2_grpc, "LeaseStub", return_value=mock_lease_stub), \
patch('builtins.open', mock_open(read_data=b"mock_content")):
client = EtcdClient(
etcd_config=etcd_config,
tls_config=tls_config
)
return client
def test_init_default_parameters():
"""Test initialization with default parameters."""
etcd_config = EtcdConfig()
client = EtcdClient(etcd_config=etcd_config)
assert client.host == "etcd.default.svc.cluster.local"
assert client.port == 2379
assert client.tls_config is None
assert client.timeout == 5
assert client.channel is not None
assert client.kv_stub is not None
assert client.lease_stub is not None
assert client._leases == {}
def test_init_with_certificates():
"""Test initialization with certificate files."""
mock_channel = MagicMock()
mock_kv_stub = MagicMock()
mock_lease_stub = MagicMock()
tls_config = TLSConfig(
enable_tls=True,
ca_file="ca_cert_path",
key_file="cert_key_path",
cert_file="cert_cert_path"
)
etcd_config = EtcdConfig(
etcd_host="test_host",
etcd_port=1234,
etcd_timeout=10
)
with patch("grpc.secure_channel", return_value=mock_channel), \
patch.object(rpc_pb2_grpc, "KVStub", return_value=mock_kv_stub), \
patch.object(rpc_pb2_grpc, "LeaseStub", return_value=mock_lease_stub), \
patch('builtins.open', mock_open(read_data=b"mock_content")):
client = EtcdClient(
etcd_config=etcd_config,
tls_config=tls_config
)
assert client.host == "test_host"
assert client.port == 1234
assert client.tls_config == tls_config
assert client.tls_config.ca_file == "ca_cert_path"
assert client.tls_config.key_file == "cert_key_path"
assert client.tls_config.cert_file == "cert_cert_path"
assert client.timeout == 10
assert client.channel == mock_channel
assert client.kv_stub == mock_kv_stub
assert client.lease_stub == mock_lease_stub
def test_init_with_missing_certificates():
"""Test initialization with missing certificate files."""
tls_config = TLSConfig(
enable_tls=True,
ca_file="ca_cert_path",
key_file="cert_key_path",
cert_file="cert_cert_path"
)
etcd_config = EtcdConfig(
etcd_host="test_host",
etcd_port=1234,
etcd_timeout=10
)
with patch("builtins.open", side_effect=FileNotFoundError("File not found")):
client = EtcdClient(
etcd_config=etcd_config,
tls_config=tls_config
)
assert client.host == "test_host"
def test_get_key_with_namespace_and_job_name_already_has_namespace(monkeypatch):
test_namespace = "test_namespace"
monkeypatch.setattr("motor.common.etcd.etcd_client.namespace", test_namespace)
job_name = "test_job_name"
monkeypatch.setattr("motor.common.etcd.etcd_client.job_name", job_name)
key = "test_namespace/test_job_name/key"
result = EtcdClient.get_key_with_namespace_and_job_name(key)
assert result == key
def test_get_key_with_namespace_and_job_name_no_prefix_has_namespace(monkeypatch):
test_namespace = "test_namespace"
monkeypatch.setattr("motor.common.etcd.etcd_client.namespace", test_namespace)
key = "key"
result = EtcdClient.get_key_with_namespace_and_job_name(key)
assert result != key
def test_prefix_range_end_normal_case():
"""Test the normal case where the prefix does not end with 0xff."""
prefix = "test"
expected_result = b"tesu"
assert EtcdClient._prefix_range_end(prefix) == expected_result
def test_prefix_range_end_empty_string():
"""Test the case where the prefix is an empty string."""
prefix = ""
expected_result = b""
assert EtcdClient._prefix_range_end(prefix) == expected_result
def test_acquire_lock_success(base_client_with_ssl):
with patch.object(locks, "Lock") as mock_lock_class:
mock_lock = MagicMock()
mock_lock.acquire.return_value = True
mock_lock.lease_id = 12345
mock_uuid_bytes = b"test_uuid_bytes"
mock_lock.uuid = mock_uuid_bytes
mock_lock_class.return_value = mock_lock
result = base_client_with_ssl.acquire_lock("test_key")
assert result is not None
assert result == str(mock_uuid_bytes)
def test_acquire_lock_failure(base_client_with_ssl):
with patch.object(locks, "Lock") as mock_lock_class:
mock_lock = MagicMock()
mock_lock.acquire.return_value = False
mock_lock_class.return_value = mock_lock
result = base_client_with_ssl.acquire_lock("test_key")
assert result is None
assert "test_key" not in base_client_with_ssl._leases
def test_acquire_lock_exception(base_client_with_ssl):
with patch.object(locks, "Lock", side_effect=Exception("Test exception")):
result = base_client_with_ssl.acquire_lock("test_key")
assert result is None
assert "test_key" not in base_client_with_ssl._leases
def test_renew_lease_success(base_client_with_ssl):
lock_key = "test_lock"
lease_id = 123
new_ttl = 60
base_client_with_ssl._leases[lock_key] = lease_id
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
response = MagicMock()
response.TTL = new_ttl
response_stream = iter([response])
base_client_with_ssl.lease_stub.LeaseKeepAlive.return_value = response_stream
result = base_client_with_ssl.renew_lease(lock_key)
assert result is True
base_client_with_ssl.lease_stub.LeaseKeepAlive.assert_called_once()
def test_renew_lease_lock_not_found(base_client_with_ssl):
lock_key = "test_lock"
base_client_with_ssl.get_key_with_namespace_and_job_name.return_value = lock_key
result = base_client_with_ssl.renew_lease(lock_key)
assert result is False
base_client_with_ssl.lease_stub.LeaseKeepAlive.assert_not_called()
def test_renew_lease_expired(base_client_with_ssl):
lock_key = "test_lock"
lease_id = 123
new_ttl = 0
base_client_with_ssl._leases[lock_key] = lease_id
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
response = MagicMock()
response.TTL = new_ttl
response_stream = iter([response])
base_client_with_ssl.lease_stub.LeaseKeepAlive.return_value = response_stream
result = base_client_with_ssl.renew_lease(lock_key)
assert result is False
base_client_with_ssl.lease_stub.LeaseKeepAlive.assert_called_once()
def test_renew_lease_exception(base_client_with_ssl):
lock_key = "test_lock"
lease_id = 123
base_client_with_ssl._leases[lock_key] = lease_id
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
base_client_with_ssl.lease_stub.LeaseKeepAlive.side_effect = Exception("Test exception")
result = base_client_with_ssl.renew_lease(lock_key)
assert result is False
base_client_with_ssl.lease_stub.LeaseKeepAlive.assert_called_once()
def test_release_lock_success(base_client_with_ssl):
"""Test successful lock release."""
lock_key = "test_lock"
lease_id = 12345
base_client_with_ssl._leases[lock_key] = lease_id
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
result = base_client_with_ssl.release_lock(lock_key)
assert result is True
assert lock_key not in base_client_with_ssl._leases
def test_release_lock_not_exist(base_client_with_ssl):
"""Test releasing a non-existent lock."""
lock_key = "non_existent_lock"
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
result = base_client_with_ssl.release_lock(lock_key)
assert result is False
assert lock_key not in base_client_with_ssl._leases
def test_release_lock_exception(base_client_with_ssl):
"""Test exception during lock release."""
lock_key = "test_lock"
lease_id = 12345
base_client_with_ssl._leases[lock_key] = lease_id
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=lock_key)
base_client_with_ssl.lease_stub.LeaseRevoke.side_effect = Exception("Test exception")
result = base_client_with_ssl.release_lock(lock_key)
assert result is False
assert lock_key not in base_client_with_ssl._leases
@pytest.fixture
def mock_logger():
with patch('motor.common.etcd.etcd_client.logger') as mock_logger:
yield mock_logger
def test_put_json_with_pydantic_model(base_client_with_ssl, mock_logger):
class TestModel(BaseModel):
field: str
model = TestModel(field="value")
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value="key")
base_client_with_ssl.put_json("key", model)
base_client_with_ssl.kv_stub.Put.assert_called_once()
mock_logger.info.assert_called_once_with("Stored JSON data for key %s", "key")
def test_put_json_with_dict(base_client_with_ssl, mock_logger):
data = {"field": "value"}
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value="key")
base_client_with_ssl.put_json("key", data)
base_client_with_ssl.kv_stub.Put.assert_called_once_with(
rpc_pb2.PutRequest(
key=b'key',
value=json.dumps(data, ensure_ascii=False).encode('utf-8')
),
timeout=10
)
mock_logger.info.assert_called_once_with("Stored JSON data for key %s", "key")
def test_put_json_with_lease(base_client_with_ssl, mock_logger):
data = {"field": "value"}
lease = 123
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value="key")
base_client_with_ssl.put_json("key", data, lease=lease)
base_client_with_ssl.kv_stub.Put.assert_called_once_with(
rpc_pb2.PutRequest(
key=b'key',
value=json.dumps(data, ensure_ascii=False).encode('utf-8'),
lease=lease
),
timeout=10
)
mock_logger.info.assert_called_once_with("Stored JSON data for key %s", "key")
def test_delete_prefix_success(base_client_with_ssl):
mock_response = MagicMock()
mock_response.deleted = 5
base_client_with_ssl.kv_stub.DeleteRange.return_value = mock_response
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value="test_prefix")
base_client_with_ssl._prefix_range_end = MagicMock(return_value="test_range_end".encode('utf-8'))
result = base_client_with_ssl.delete_prefix("test_prefix")
assert result is True
base_client_with_ssl.kv_stub.DeleteRange.assert_called_once()
base_client_with_ssl.get_key_with_namespace_and_job_name.assert_called_once_with("test_prefix")
base_client_with_ssl._prefix_range_end.assert_called_once_with("test_prefix")
def test_delete_prefix_failure(base_client_with_ssl):
base_client_with_ssl.kv_stub.DeleteRange.side_effect = Exception("Delete failed")
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value="test_prefix")
base_client_with_ssl._prefix_range_end = MagicMock(return_value="test_range_end".encode('utf-8'))
result = base_client_with_ssl.delete_prefix("test_prefix")
assert result is False
base_client_with_ssl.kv_stub.DeleteRange.assert_called_once()
base_client_with_ssl.get_key_with_namespace_and_job_name.assert_called_once_with("test_prefix")
base_client_with_ssl._prefix_range_end.assert_called_once_with("test_prefix")
def test_delete_key_success(base_client_with_ssl):
key = "test_key"
expected_key_with_namespace = "namespace_test_key"
response = rpc_pb2.DeleteRangeResponse(deleted=1)
base_client_with_ssl.kv_stub.DeleteRange.return_value = response
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=expected_key_with_namespace)
result = base_client_with_ssl.delete_key(key)
assert result is True
base_client_with_ssl.kv_stub.DeleteRange.assert_called_once_with(
rpc_pb2.DeleteRangeRequest(key=expected_key_with_namespace.encode('utf-8')),
timeout=base_client_with_ssl.timeout
)
def test_delete_key_not_found(base_client_with_ssl):
key = "test_key"
expected_key_with_namespace = "namespace_test_key"
response = rpc_pb2.DeleteRangeResponse(deleted=0)
base_client_with_ssl.kv_stub.DeleteRange.return_value = response
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=expected_key_with_namespace)
result = base_client_with_ssl.delete_key(key)
assert result is True
base_client_with_ssl.kv_stub.DeleteRange.assert_called_once_with(
rpc_pb2.DeleteRangeRequest(key=expected_key_with_namespace.encode('utf-8')),
timeout=base_client_with_ssl.timeout
)
def test_delete_key_exception(base_client_with_ssl):
key = "test_key"
expected_key_with_namespace = "namespace_test_key"
base_client_with_ssl.kv_stub.DeleteRange.side_effect = Exception("Test exception")
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=expected_key_with_namespace)
result = base_client_with_ssl.delete_key(key)
assert result is False
base_client_with_ssl.kv_stub.DeleteRange.assert_called_once_with(
rpc_pb2.DeleteRangeRequest(key=expected_key_with_namespace.encode('utf-8')),
timeout=base_client_with_ssl.timeout
)
def test_lock_context_success(base_client_with_ssl):
"""Test successful lock acquisition and release using context manager."""
lease_id = "lease123"
base_client_with_ssl.acquire_lock = MagicMock(return_value=lease_id)
base_client_with_ssl.release_lock = MagicMock(return_value=True)
with base_client_with_ssl.lock_context("test_lock") as acquired_lease_id:
assert acquired_lease_id == lease_id
base_client_with_ssl.release_lock.assert_called_once_with("test_lock")
def test_lock_context_acquire_failure(base_client_with_ssl):
"""Test failure to acquire lock using context manager."""
base_client_with_ssl.acquire_lock = MagicMock(return_value=None)
with pytest.raises(RuntimeError, match="Failed to acquire lock test_lock"):
with base_client_with_ssl.lock_context("test_lock"):
pass
def test_close_success(base_client_with_ssl):
"""Test that the close method successfully closes the channel."""
mock_channel = MagicMock()
mock_lock = MagicMock()
base_client_with_ssl.channel = mock_channel
base_client_with_ssl._lock = mock_lock
base_client_with_ssl.close()
mock_channel.close.assert_called_once()
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
def test_close_with_exception(base_client_with_ssl):
"""Test that the close method handles exceptions properly."""
mock_channel = MagicMock()
mock_lock = MagicMock()
base_client_with_ssl.channel = mock_channel
base_client_with_ssl._lock = mock_lock
mock_channel.close.side_effect = Exception("Test exception")
base_client_with_ssl.close()
mock_channel.close.assert_called_once()
mock_lock.__enter__.assert_called_once()
mock_lock.__exit__.assert_called_once()
def test_persist_data_success(base_client_with_ssl):
mock_lock_context = MagicMock()
mock_lock_context.return_value.__enter__ = MagicMock(return_value=True)
mock_lock_context.return_value.__exit__ = MagicMock()
base_client_with_ssl.lock_context = mock_lock_context
base_client_with_ssl.delete_prefix = MagicMock(return_value=True)
base_client_with_ssl.put_json = MagicMock(return_value=True)
key_prefix = "test_prefix"
data = {"key1": {"value": "value1"}, "key2": {"value": "value2"}}
result = base_client_with_ssl.persist_data(key_prefix, data)
assert result is True
base_client_with_ssl.lock_context.assert_called_once_with(f"persist_{key_prefix}", ttl=30)
base_client_with_ssl.delete_prefix.assert_called_once_with(key_prefix)
assert base_client_with_ssl.put_json.call_count == len(data)
for data_key in data.keys():
base_client_with_ssl.put_json.assert_any_call(f"{key_prefix}/{data_key}", data[data_key])
def test_persist_data_lock_failure(base_client_with_ssl):
base_client_with_ssl.lock_context = MagicMock(side_effect=Exception("Lock failed"))
base_client_with_ssl.delete_prefix = MagicMock(return_value=True)
base_client_with_ssl.put_json = MagicMock(return_value=True)
key_prefix = "test_prefix"
data = {"key1": {"value": "value1"}, "key2": {"value": "value2"}}
result = base_client_with_ssl.persist_data(key_prefix, data)
assert result is False
base_client_with_ssl.lock_context.assert_called_once_with(f"persist_{key_prefix}", ttl=30)
base_client_with_ssl.delete_prefix.assert_not_called()
base_client_with_ssl.put_json.assert_not_called()
def test_persist_data_put_json_failure(base_client_with_ssl):
mock_lock_context = MagicMock()
mock_lock_context.return_value.__enter__ = MagicMock(return_value=True)
mock_lock_context.return_value.__exit__ = MagicMock()
base_client_with_ssl.lock_context = mock_lock_context
base_client_with_ssl.delete_prefix = MagicMock(return_value=True)
base_client_with_ssl.put_json = MagicMock(side_effect=[True, False])
key_prefix = "test_prefix"
data = {"key1": {"value": "value1"}, "key2": {"value": "value2"}}
result = base_client_with_ssl.persist_data(key_prefix, data)
assert result is False
base_client_with_ssl.lock_context.assert_called_once_with(f"persist_{key_prefix}", ttl=30)
base_client_with_ssl.delete_prefix.assert_called_once_with(key_prefix)
assert base_client_with_ssl.put_json.call_count == len(data)
for data_key in data.keys():
base_client_with_ssl.put_json.assert_any_call(f"{key_prefix}/{data_key}", data[data_key])
def test_restore_data_success(base_client_with_ssl):
"""Test successful data restoration."""
class TestRestoreModel(BaseModel):
field1: str
field2: int
test_data = {
"key1": {"field1": "value1", "field2": 1},
"key2": {"field1": "value2", "field2": 2}
}
with patch.object(EtcdClient, 'get_prefix_data', return_value=test_data):
result = base_client_with_ssl.restore_data("test_prefix", TestRestoreModel)
assert result == test_data
def test_restore_data_exception(base_client_with_ssl):
"""Test exception during data restoration."""
class TestRestoreModel(BaseModel):
field1: str
field2: int
with patch.object(EtcdClient, 'get_prefix_data', side_effect=Exception("Test exception")):
result = base_client_with_ssl.restore_data("test_prefix", TestRestoreModel)
assert result is None
def test_get_prefix_data_success(base_client_with_ssl):
key_prefix = "test_prefix"
mock_response = MagicMock()
mock_response.kvs = [
MagicMock(key=b"test_prefix/key1", value=json.dumps({"key1": "value1"}).encode("utf-8")),
MagicMock(key=b"test_prefix/key2", value=json.dumps({"key2": "value2"}).encode("utf-8")),
]
base_client_with_ssl.kv_stub.Range.return_value = mock_response
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=key_prefix)
base_client_with_ssl._prefix_range_end.return_value = b"test_prefix0"
result = base_client_with_ssl.get_prefix_data(key_prefix)
assert result == {
"key1": {"key1": "value1"},
"key2": {"key2": "value2"},
}
base_client_with_ssl.kv_stub.Range.assert_called_once()
def test_get_prefix_data_with_model_class(base_client_with_ssl):
class TestModel(BaseModel):
key: str
key_prefix = "test_prefix"
mock_response = MagicMock()
mock_response.kvs = [
MagicMock(key=b"test_prefix/key1", value=json.dumps({"key": "value1"}).encode("utf-8")),
]
base_client_with_ssl.kv_stub.Range.return_value = mock_response
base_client_with_ssl.get_key_with_namespace_and_job_name = MagicMock(return_value=key_prefix)
base_client_with_ssl._prefix_range_end.return_value = b"test_prefix0"
result = base_client_with_ssl.get_prefix_data(key_prefix, model_class=TestModel)
assert isinstance(result["key1"], TestModel)
assert result["key1"].key == "value1"
def test_get_prefix_data_exception(base_client_with_ssl):
key_prefix = "test_prefix"
base_client_with_ssl.kv_stub.Range.side_effect = Exception("Test exception")
result = base_client_with_ssl.get_prefix_data(key_prefix)
assert result == {}