import time
from unittest.mock import patch, MagicMock
import pytest
from fastapi import HTTPException
from motor.controller.core.instance_manager import InstanceManager, PersistentState
from motor.common.resources.endpoint import Endpoint, EndpointStatus
from motor.common.resources.http_msg_spec import HeartbeatMsg
from motor.common.resources.instance import (
ParallelConfig,
Instance,
NodeManagerInfo,
InsStatus,
InsConditionEvent,
ReadOnlyInstance,
)
from motor.common.resources import EventType
from motor.common.utils.singleton import ThreadSafeSingleton
from motor.config.controller import ControllerConfig
from motor.controller.core.event_pusher import EventPusher
def create_test_instance(instance_id: int, job_name: str, pod_ips: list[str], role: str = "prefill") -> Instance:
    """Build a test ``Instance`` that owns one NORMAL endpoint per pod IP.

    Args:
        instance_id: Value used as the instance ``id``.
        job_name: Value used as the instance ``job_name``.
        pod_ips: Pod IPs; each becomes a key of ``endpoints`` that maps to a
            single endpoint stored under index ``0``.
        role: Instance role, ``"prefill"`` unless overridden.

    Returns:
        The constructed ``Instance`` (``model_name`` is ``"test_model"``).
    """
    per_pod_endpoints = {
        pod_ip: {
            0: Endpoint(
                id=0,
                ip=pod_ip,
                # Ports embed the pod's position in the list: "800<n>" and "801<n>".
                business_port=f"80{0}{position}",
                mgmt_port=f"80{1}{position}",
                status=EndpointStatus.NORMAL,
                hb_timestamp=time.time(),
            )
        }
        for position, pod_ip in enumerate(pod_ips)
    }
    return Instance(
        id=instance_id,
        job_name=job_name,
        model_name="test_model",
        role=role,
        endpoints=per_pod_endpoints,
    )
def _create_endpoint(endpoint_id: int, ip: str, business_port: str = "9090", mgmt_port: str = "8080") -> Endpoint:
    """Return an ``Endpoint`` in INITIAL status with no device infos.

    Args:
        endpoint_id: Value used as the endpoint ``id``.
        ip: Endpoint IP address.
        business_port: Business port, ``"9090"`` by default.
        mgmt_port: Management port, ``"8080"`` by default.

    Returns:
        The new ``Endpoint``; its heartbeat timestamp is the current time.
    """
    endpoint_fields = dict(
        id=endpoint_id,
        ip=ip,
        business_port=business_port,
        mgmt_port=mgmt_port,
        status=EndpointStatus.INITIAL,
        device_infos=[],
        hb_timestamp=time.time(),
    )
    return Endpoint(**endpoint_fields)
def get_mock_heartbeat_msg(job_name: str, ins_id: int, ip: str, status_dict: dict = None) -> HeartbeatMsg:
    """Build a ``HeartbeatMsg`` for tests.

    Args:
        job_name: Job name carried by the message.
        ins_id: Instance id carried by the message.
        ip: IP carried by the message.
        status_dict: Endpoint-index -> status mapping. When ``None``, a single
            endpoint ``0`` reporting ``EndpointStatus.NORMAL`` is used.

    Returns:
        The constructed ``HeartbeatMsg``.
    """
    reported_status = {0: EndpointStatus.NORMAL} if status_dict is None else status_dict
    return HeartbeatMsg(job_name=job_name, ins_id=ins_id, ip=ip, status=reported_status)
def create_instance_manager_with_config(enable_etcd=False) -> InstanceManager:
    """Return an ``InstanceManager`` built from a freshly created ``ControllerConfig``.

    Args:
        enable_etcd: Value written to ``etcd_config.enable_etcd_persistence``.

    Returns:
        The result of ``InstanceManager(config)``.
    """
    controller_config = ControllerConfig()
    controller_config.etcd_config.enable_etcd_persistence = enable_etcd
    # NOTE(review): test_initialization_with_config reads this setting from
    # ``config.instance_config`` -- confirm that this top-level assignment is
    # the attribute InstanceManager actually consumes.
    controller_config.instance_manager_check_interval = 0.1
    return InstanceManager(controller_config)
@pytest.fixture
def test_config():
    """Provide the shared parallelism, role and pod-IP settings used by the tests.

    Returns:
        A dict with keys ``dp``, ``tp``, ``p_role``, ``d_role``, ``pod_ips``,
        ``p_parallel_config`` and ``d_parallel_config``.
    """
    dp_size, tp_size = 8, 2
    settings = {
        'dp': dp_size,
        'tp': tp_size,
        'p_role': "prefill",
        'd_role': "decode",
        # Eight loopback addresses: 127.0.0.1 .. 127.0.0.8
        'pod_ips': [f"127.0.0.{i}" for i in range(1, 9)],
        'p_parallel_config': ParallelConfig(dp_size=dp_size, tp_size=tp_size),
        # Decode uses four times the dp size and half the tp size of prefill.
        'd_parallel_config': ParallelConfig(dp_size=dp_size * 4, tp_size=tp_size // 2),
    }
    return settings
@pytest.fixture(autouse=True)
def mock_etcd_client():
    """Patch ``EtcdClient`` in the instance_manager module for every test.

    Yields:
        The ``MagicMock`` returned whenever the patched class is called;
        ``persist_data`` returns ``True`` and ``restore_data`` returns ``None``.
    """
    fake_client = MagicMock()
    fake_client.persist_data.return_value = True
    fake_client.restore_data.return_value = None
    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = fake_client
        yield fake_client
def _discard_instance_manager_singleton() -> None:
    """Stop and forget the cached ``InstanceManager`` singleton, if one exists."""
    cached = getattr(ThreadSafeSingleton, '_instances', None)
    if cached is None or InstanceManager not in cached:
        return
    try:
        cached[InstanceManager].stop()
    except Exception:
        # Best effort only: a manager that fails to stop must not prevent the
        # cache entry from being dropped below.
        pass
    del cached[InstanceManager]


@pytest.fixture(autouse=True)
def setup_test_environment():
    """Give every test a fresh ``InstanceManager`` singleton.

    The cached singleton is discarded before the test runs and again after it
    finishes, so a manager created by one test is not left behind for the
    next test (or for other test modules).
    """
    _discard_instance_manager_singleton()
    yield
    _discard_instance_manager_singleton()
@pytest.fixture
def instance_manager(test_config):
    """Provide an ``InstanceManager`` pre-populated with three instances.

    * ``prefill-0`` (id 0) on pod IPs 0-1 and ``prefill-1`` (id 1) on pod IPs
      2-3, each pod holding one endpoint under index ``0``.
    * ``decode-0`` (id 2) on pod IPs 4-7, each pod holding eight endpoints.
    """
    manager = create_instance_manager_with_config()
    ips = test_config['pod_ips']

    def node_manager_for(ip):
        # Pod IP and host IP are the same address in these tests.
        return NodeManagerInfo(pod_ip=ip, host_ip=ip, port="8080")

    prefill_layout = (
        ("prefill-0", 0, ips[0:2]),
        ("prefill-1", 1, ips[2:4]),
    )
    for prefill_job, prefill_id, prefill_ips in prefill_layout:
        manager.add_instance(
            Instance(
                job_name=prefill_job,
                model_name="test_model",
                id=prefill_id,
                role=test_config['p_role'],
                parallel_config=test_config['p_parallel_config'],
                node_mgrs=[node_manager_for(ip) for ip in prefill_ips],
                endpoints={ip: {0: _create_endpoint(0, ip)} for ip in prefill_ips},
            )
        )

    decode_ips = ips[4:8]
    decode_instance = Instance(
        job_name="decode-0",
        model_name="test_model",
        id=2,
        role=test_config['d_role'],
        parallel_config=test_config['d_parallel_config'],
        node_mgrs=[node_manager_for(ip) for ip in decode_ips],
        endpoints={},
    )
    for ip in decode_ips:
        # Business ports run 8080..8087; each management port is 1000 higher.
        pod_endpoints = {
            index: _create_endpoint(
                endpoint_id=index,
                ip=ip,
                business_port=str(8080 + index),
                mgmt_port=str(9080 + index),
            )
            for index in range(8)
        }
        decode_instance.add_endpoints(ip, pod_endpoints)
    manager.add_instance(decode_instance)
    return manager
def test_singleton_initialization():
    """Two argument-less ``InstanceManager()`` calls yield the same object."""
    first = InstanceManager()
    assert first is not None
    assert hasattr(first, '_initialized')

    second = InstanceManager()
    assert second is first
def test_initialization_with_config():
    """A manager built from a config exposes that config's etcd settings and check interval."""
    custom_config = ControllerConfig()
    custom_config.etcd_config.enable_etcd_persistence = True

    manager = InstanceManager(custom_config)

    assert manager.etcd_config is custom_config.etcd_config
    expected_interval = custom_config.instance_config.instance_manager_check_interval
    assert manager.instance_manager_check_interval == expected_interval
@patch('motor.controller.core.instance_manager.time.sleep')
def test_start_stop_manager(mock_sleep):
    """``start()`` launches the management thread; ``stop()`` sets the stop event."""
    manager = create_instance_manager_with_config()
    mock_sleep.return_value = None

    manager.start()
    running_thread = manager.instances_management_thread
    assert running_thread is not None
    assert running_thread.is_alive()
    assert not manager.stop_event.is_set()

    manager.stop()
    assert manager.stop_event.is_set()

    # Give a still-running worker a brief chance to finish.
    leftover_thread = manager.instances_management_thread
    if leftover_thread and leftover_thread.is_alive():
        leftover_thread.join(timeout=0.05)
def test_persist_data_success():
    """``persist_data`` returns ``True`` when the (mocked) etcd client succeeds."""
    manager = create_instance_manager_with_config(enable_etcd=True)
    manager.add_instance(create_test_instance(1, "test_job", ["192.168.1.1"]))

    assert manager.persist_data() is True
def test_persist_data_failure():
    """``persist_data`` returns ``False`` when the etcd client raises."""
    failing_client = MagicMock()
    failing_client.persist_data.side_effect = Exception("ETCD error")

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = failing_client
        manager = create_instance_manager_with_config(enable_etcd=True)
        manager.add_instance(create_test_instance(1, "test_job", ["192.168.1.1"]))

        outcome = manager.persist_data()
        assert outcome is False
def test_restore_data_success():
    """A state with a matching checksum is restored and a SET event is pushed."""
    stored_instance = {
        "id": 1,
        "job_name": "test_job",
        "model_name": "test_model",
        "role": "prefill",
        "endpoints": {},
        "status": "initial",
        "parallel_config": None,
        "node_managers": [],
        "gathered_workload": {"active_kv_cache": 0, "active_tokens": 0},
    }
    stored_state = PersistentState(data={"1": stored_instance}, version=1, timestamp=time.time(), checksum="")
    # Seal the state with the checksum computed from its own content.
    stored_state.checksum = stored_state.calculate_checksum()

    etcd_client = MagicMock()
    etcd_client.restore_data.return_value = {"state": stored_state}

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = etcd_client
        manager = create_instance_manager_with_config(enable_etcd=True)
        pusher = MagicMock(spec=EventPusher)
        manager.attach(pusher)

        restored = manager.restore_data()

        assert restored is True
        assert 1 in manager.instances
        pusher.push_event.assert_called_once_with(EventType.SET)
def test_restore_data_no_data():
    """``restore_data`` returns ``True`` when the etcd client returns ``None``."""
    empty_client = MagicMock()
    empty_client.restore_data.return_value = None

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = empty_client
        manager = create_instance_manager_with_config(enable_etcd=True)

        outcome = manager.restore_data()
        assert outcome is True
def test_restore_data_invalid_checksum():
    """A state whose checksum does not match is rejected and nothing is restored."""
    stored_instance = {
        "id": 1,
        "job_name": "test_job",
        "model_name": "test_model",
        "role": "prefill",
        "endpoints": {},
        "status": "initial",
        "parallel_config": None,
        "node_managers": [],
        "gathered_workload": {"active_kv_cache": 0, "active_tokens": 0},
    }
    tampered_state = PersistentState(
        data={"1": stored_instance},
        version=1,
        timestamp=time.time(),
        checksum="invalid_checksum",
    )

    etcd_client = MagicMock()
    etcd_client.restore_data.return_value = {"state": tampered_state}

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = etcd_client
        manager = create_instance_manager_with_config(enable_etcd=True)

        outcome = manager.restore_data()

        assert outcome is False
        assert 1 not in manager.instances
def test_add_instance(instance_manager, test_config):
    """Invalid arguments and a repeated id do not grow the instance count."""
    baseline = instance_manager.get_instance_num()

    # Neither ``None`` nor a plain string is counted as an instance.
    for bogus in (None, "invalid_instance"):
        instance_manager.add_instance(bogus)
        assert instance_manager.get_instance_num() == baseline

    def build_group_instance():
        return Instance(
            job_name="testAllocInsGroup2",
            model_name="test_model",
            id=100,
            role=test_config['p_role'],
            parallel_config=ParallelConfig(dp_size=test_config['dp'], tp_size=test_config['tp'] // 2),
        )

    instance_manager.add_instance(build_group_instance())
    assert instance_manager.get_instance_num() == baseline + 1

    # Adding a second instance with the same id leaves the count unchanged.
    instance_manager.add_instance(build_group_instance())
    assert instance_manager.get_instance_num() == baseline + 1
def test_del_instance(instance_manager):
    """Deleting a known id drops the count by one; an unknown id changes nothing."""
    expected_after_delete = instance_manager.get_instance_num() - 1

    instance_manager.del_instance(0)
    assert instance_manager.get_instance_num() == expected_after_delete

    instance_manager.del_instance(999)
    assert instance_manager.get_instance_num() == expected_after_delete
def test_get_instance(instance_manager):
    """``get_instance`` returns the matching instance, or ``None`` for an unknown id."""
    found = instance_manager.get_instance(1)
    assert found is not None
    assert found.id == 1

    missing = instance_manager.get_instance(999)
    assert missing is None
def test_get_instance_num(instance_manager):
    """The fixture registers exactly three instances."""
    assert instance_manager.get_instance_num() == 3
def test_get_active_instances(instance_manager):
    """Only instances whose status is ACTIVE are reported as active."""
    assert len(instance_manager.get_active_instances()) == 0

    # Flip instance 0 to ACTIVE directly on the object.
    promoted = instance_manager.get_instance(0)
    promoted.status = InsStatus.ACTIVE

    active_now = instance_manager.get_active_instances()
    assert len(active_now) == 1
    assert active_now[0].id == 0
def test_get_inactive_instances(instance_manager):
    """An instance set to INACTIVE is the only one reported as inactive."""
    demoted = instance_manager.get_instance(0)
    demoted.status = InsStatus.INACTIVE

    inactive_now = instance_manager.get_inactive_instances()
    assert len(inactive_now) == 1
    assert inactive_now[0].id == 0
def test_get_initial_instances(instance_manager):
    """All three fixture instances are reported as initial."""
    assert len(instance_manager.get_initial_instances()) == 3
def test_get_instance_by_podip(instance_manager):
    """A registered pod IP resolves to an instance; unknown or empty IPs give ``None``."""
    assert instance_manager.get_instance_by_podip("127.0.0.1") is not None

    for unknown_ip in ("192.168.1.100", ""):
        assert instance_manager.get_instance_by_podip(unknown_ip) is None
def test_has_instance_by_job_name(instance_manager):
    """``has_instance_by_job_name`` is ``True`` for a registered job and ``False`` otherwise."""
    known = instance_manager.has_instance_by_job_name("prefill-0")
    unknown = instance_manager.has_instance_by_job_name("non-existent")

    assert known is True
    assert unknown is False
def test_get_instance_by_job_name(instance_manager):
    """Lookup by job name returns the matching instance, or ``None`` when absent."""
    decode = instance_manager.get_instance_by_job_name("decode-0")
    assert decode is not None
    assert (decode.job_name, decode.id) == ("decode-0", 2)

    assert instance_manager.get_instance_by_job_name("non-existent") is None
def test_get_instance_by_job_name_returns_newest_when_stale_entries_exist(instance_manager):
    """When two instances share a job name, the lookup returns the one with the newer id."""
    # The fixture already holds "decode-0" as id 2; register a second one as id 99.
    instance_manager.add_instance(
        create_test_instance(
            instance_id=99,
            job_name="decode-0",
            pod_ips=["10.0.0.99"],
            role="decode",
        )
    )

    resolved = instance_manager.get_instance_by_job_name("decode-0")
    assert resolved is not None
    assert resolved.id == 99
def test_handle_heartbeat_success(instance_manager, test_config):
    """A heartbeat from one pod is accepted (200) and the instance stays INITIAL."""
    first_pod_ip = test_config['pod_ips'][0]
    message = get_mock_heartbeat_msg("prefill-0", 0, first_pod_ip)

    accepted, status_code = instance_manager.handle_heartbeat(message)

    assert accepted is True
    assert status_code == 200
    assert instance_manager.get_instance(0).status == InsStatus.INITIAL
def test_handle_heartbeat_invalid_message(instance_manager):
    """``None`` and a plain string are both rejected with ``(False, 500)``."""
    for bad_message in (None, "invalid_message"):
        accepted, status_code = instance_manager.handle_heartbeat(bad_message)
        assert accepted is False
        assert status_code == 500
def test_handle_heartbeat_nonexistent_instance():
    """A heartbeat for an unregistered instance raises ``HTTPException`` with status 503."""
    manager = create_instance_manager_with_config()
    orphan_message = get_mock_heartbeat_msg("non-existent", 999, "192.168.1.1")

    with pytest.raises(HTTPException) as raised:
        manager.handle_heartbeat(orphan_message)

    # Inspect the exception only after the ``raises`` block has captured it.
    assert raised.value.status_code == 503
def test_state_transitions(instance_manager, test_config):
    """Walk instance 0 through INITIAL -> ACTIVE -> INACTIVE using heartbeats."""
    first_ip, second_ip = test_config['pod_ips'][0], test_config['pod_ips'][1]
    tracked = instance_manager.get_instance(0)

    # A heartbeat from the first pod alone leaves the instance INITIAL.
    instance_manager.handle_heartbeat(get_mock_heartbeat_msg("prefill-0", 0, first_ip))
    assert tracked.status == InsStatus.INITIAL

    # Mark every endpoint NORMAL, then heartbeat from the second pod.
    for pod_endpoints in tracked.endpoints.values():
        for ep in pod_endpoints.values():
            ep.status = EndpointStatus.NORMAL
    instance_manager.handle_heartbeat(get_mock_heartbeat_msg("prefill-0", 0, second_ip))
    assert tracked.status == InsStatus.ACTIVE

    # An ABNORMAL endpoint report moves the instance to INACTIVE.
    abnormal_report = get_mock_heartbeat_msg("prefill-0", 0, second_ip, {0: EndpointStatus.ABNORMAL})
    instance_manager.handle_heartbeat(abnormal_report)
    assert tracked.status == InsStatus.INACTIVE
def test_separate_instance(instance_manager):
    """The first separation deactivates and persists; a repeat call does neither."""
    instance_manager.etcd_config.enable_etcd_persistence = True
    target = create_test_instance(100, "test_separate", ["192.168.1.1"])
    instance_manager.add_instance(target)
    target.update_instance_status(InsStatus.ACTIVE)

    # First call: ACTIVE -> INACTIVE, recorded as forcibly separated, persisted once.
    with patch.object(instance_manager, 'persist_data', return_value=True) as first_persist:
        instance_manager.separate_instance(target.id)
        assert target.status == InsStatus.INACTIVE
        assert target.id in instance_manager.forced_separated_instances
        first_persist.assert_called_once()

    # Second call: status is unchanged and nothing is persisted.
    with patch.object(instance_manager, 'persist_data', return_value=True) as second_persist:
        status_before = target.status
        instance_manager.separate_instance(target.id)
        assert target.status == status_before
        assert target.id in instance_manager.forced_separated_instances
        second_persist.assert_not_called()
def test_separate_nonexistent_instance(instance_manager):
    """Separating an id that was never registered completes without raising."""
    unknown_id = 999
    instance_manager.separate_instance(unknown_id)
def test_recover_instance(instance_manager):
    """``recover_instance`` removes an id from the forced-separated set."""
    target = create_test_instance(101, "test_recover", ["192.168.1.2"])
    instance_manager.add_instance(target)
    target.update_instance_status(InsStatus.ACTIVE)

    instance_manager.separate_instance(target.id)
    assert target.id in instance_manager.forced_separated_instances

    instance_manager.recover_instance(target.id)
    assert target.id not in instance_manager.forced_separated_instances
def test_recover_nonexistent_instance(instance_manager):
    """Recovering an id that was never registered completes without raising."""
    unknown_id = 999
    instance_manager.recover_instance(unknown_id)
def test_observer_pattern(instance_manager):
    """Attached observers are notified on ``add_instance`` and on explicit ``notify``."""
    from motor.controller.core import Observer, ObserverEvent

    class RecordingObserver(Observer):
        """Observer that stores each ``(instance id, event)`` pair it receives."""

        def __init__(self):
            self.received = []

        def update(self, instance: ReadOnlyInstance, event: ObserverEvent):
            self.received.append((instance.id, event))

    recorder = RecordingObserver()
    instance_manager.attach(recorder)

    watched = create_test_instance(102, "test_observer", ["192.168.1.3"])
    instance_manager.add_instance(watched)
    # Adding the instance produces exactly one INSTANCE_INITIAL notification.
    assert recorder.received == [(102, ObserverEvent.INSTANCE_INITIAL)]

    recorder.received.clear()
    instance_manager.notify(watched, ObserverEvent.INSTANCE_READY)
    assert recorder.received == [(102, ObserverEvent.INSTANCE_READY)]
def test_handle_initial_state():
    """After ``_handle_initial`` the instance is not in the forced-separated set."""
    manager = create_instance_manager_with_config()
    subject = create_test_instance(1, "test_initial", ["192.168.1.1"])
    manager.add_instance(subject)
    subject.update_instance_status(InsStatus.INACTIVE)

    manager._handle_initial(InsStatus.INACTIVE, InsConditionEvent.INSTANCE_INIT, subject)

    assert subject.id not in manager.forced_separated_instances
def test_handle_active_state(instance_manager):
    """``_handle_active`` moves an INITIAL instance to ACTIVE."""
    subject = instance_manager.get_instance(0)
    subject.update_instance_status(InsStatus.INITIAL)

    instance_manager._handle_active(InsStatus.INITIAL, InsConditionEvent.INSTANCE_NORMAL, subject)

    assert subject.status == InsStatus.ACTIVE
def test_handle_inactive_state(instance_manager):
    """``_handle_inactive`` moves an ACTIVE instance to INACTIVE."""
    subject = instance_manager.get_instance(0)
    subject.update_instance_status(InsStatus.ACTIVE)

    instance_manager._handle_inactive(InsStatus.ACTIVE, InsConditionEvent.INSTANCE_ABNORMAL, subject)

    assert subject.status == InsStatus.INACTIVE
def test_handle_deleted_state(instance_manager):
    """``_handle_deleted`` marks the instance DELETED and removes it from the manager."""
    doomed = create_test_instance(103, "test_deleted", ["192.168.1.4"])
    instance_manager.add_instance(doomed)
    doomed.update_instance_status(InsStatus.INACTIVE)

    instance_manager._handle_deleted(
        InsStatus.INACTIVE,
        InsConditionEvent.INSTANCE_HEARTBEAT_TIMEOUT,
        doomed,
    )

    assert doomed.status == InsStatus.DELETED
    assert instance_manager.get_instance(103) is None
def test_refresh_instance_heartbeat(instance_manager):
    """Endpoint heartbeat timestamps can be overwritten with a newer value.

    NOTE(review): this test only assigns ``hb_timestamp`` directly on the
    endpoints; it does not call any ``InstanceManager`` refresh API.
    """
    instance = instance_manager.get_instance(0)

    def iter_endpoints():
        # Re-read ``instance.endpoints`` on every pass instead of caching objects.
        for pod_endpoints in instance.endpoints.values():
            yield from pod_endpoints.values()

    # Start from a timestamp that is clearly in the past.
    stale_timestamp = time.time() - 100
    for endpoint in iter_endpoints():
        endpoint.hb_timestamp = stale_timestamp

    # Refresh without a blanket ``except``: an assignment error should fail
    # the test loudly rather than be swallowed.
    refreshed_timestamp = time.time()
    for endpoint in iter_endpoints():
        endpoint.hb_timestamp = refreshed_timestamp

    checked = 0
    for endpoint in iter_endpoints():
        assert endpoint.hb_timestamp == refreshed_timestamp
        checked += 1
    # Guard against the loop above passing vacuously.
    assert checked > 0
def test_version_control():
    """``_get_next_version`` increments ``_data_version`` by one and returns it."""
    manager = create_instance_manager_with_config()
    assert manager._data_version == 0

    for expected_version in (1, 2):
        issued = manager._get_next_version()
        assert issued == expected_version
        assert manager._data_version == expected_version
def test_checksum_calculation(instance_manager):
    """Checksums are non-empty strings and differ for different instance data."""
    fixture_dump = instance_manager.get_instance(0).model_dump()
    fixture_state = PersistentState(data=fixture_dump, version=1, timestamp=time.time(), checksum="")
    fixture_checksum = fixture_state.calculate_checksum()
    assert isinstance(fixture_checksum, str)
    assert len(fixture_checksum) > 0

    other_dump = create_test_instance(999, "different_job", ["192.168.1.99"]).model_dump()
    other_state = PersistentState(data=other_dump, version=1, timestamp=time.time(), checksum="")
    other_checksum = other_state.calculate_checksum()

    assert fixture_checksum != other_checksum
def test_persistent_instance_state():
    """``is_valid`` accepts the computed checksum and rejects an altered one."""
    snapshot = PersistentState(
        data={"id": 1, "job_name": "test"},
        version=1,
        timestamp=time.time(),
        checksum="",
    )

    snapshot.checksum = snapshot.calculate_checksum()
    assert snapshot.is_valid()

    snapshot.checksum = "invalid"
    assert not snapshot.is_valid()
def test_forced_separation_cleanup(instance_manager):
    """Deleting an instance also removes it from the forced-separated set."""
    target = create_test_instance(104, "test_cleanup", ["192.168.1.5"])
    instance_manager.add_instance(target)

    instance_manager.separate_instance(target.id)
    assert target.id in instance_manager.forced_separated_instances

    instance_manager.del_instance(target.id)
    assert target.id not in instance_manager.forced_separated_instances
def test_instances_management_loop_timeout():
    """A heartbeat timeout on an ACTIVE instance moves it to INACTIVE.

    The transition table and state handlers are driven by hand here rather
    than through the management thread.
    """
    manager = create_instance_manager_with_config()
    instance = create_test_instance(105, "test_timeout", ["192.168.1.6"])
    manager.add_instance(instance)
    instance.update_instance_status(InsStatus.ACTIVE)

    # Age every endpoint heartbeat far into the past.
    for pod_endpoints in instance.endpoints.values():
        for endpoint in pod_endpoints.values():
            endpoint.hb_timestamp = time.time() - 1000

    from_state = instance.status
    event = InsConditionEvent.INSTANCE_HEARTBEAT_TIMEOUT

    # Fail with a precise message when the transition or its handler is
    # missing, instead of silently skipping the handler call.
    to_state = manager.transitions.get((from_state, event), None)
    assert to_state, f"no transition registered for ({from_state}, {event})"
    state_handler = manager.states.get(to_state, None)
    assert state_handler, f"no state handler registered for {to_state}"

    state_handler(from_state, event, instance)

    assert instance.status == InsStatus.INACTIVE
def test_persistence_on_state_change():
    """Handling a heartbeat with etcd persistence enabled calls ``persist_data``."""
    with patch.object(InstanceManager, 'persist_data') as persist_spy:
        manager = create_instance_manager_with_config(enable_etcd=True)
        tracked = create_test_instance(106, "test_persist", ["192.168.1.7"])
        manager.add_instance(tracked)
        tracked.update_instance_status(InsStatus.INITIAL)

        manager.handle_heartbeat(get_mock_heartbeat_msg("test_persist", 106, "192.168.1.7"))

        persist_spy.assert_called()
def test_prevent_forced_separation_reactivation():
    """A forcibly separated instance stays INACTIVE even when all endpoints are NORMAL."""
    manager = create_instance_manager_with_config()
    separated = create_test_instance(107, "test_prevent", ["192.168.1.8"])
    manager.add_instance(separated)
    separated.update_instance_status(InsStatus.ACTIVE)

    manager.separate_instance(separated.id)
    assert separated.status == InsStatus.INACTIVE
    assert separated.id in manager.forced_separated_instances

    # Healthy endpoints alone must not bring the instance back to ACTIVE.
    for pod_endpoints in separated.endpoints.values():
        for ep in pod_endpoints.values():
            ep.status = EndpointStatus.NORMAL

    handled = manager._handle_state_transition(separated)
    assert handled is True
    assert separated.status == InsStatus.INACTIVE
def test_update_config():
    """``update_config`` adopts the new etcd settings and constructs a new ``EtcdClient``."""
    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = MagicMock()
        manager = create_instance_manager_with_config(enable_etcd=True)

        replacement = ControllerConfig()
        etcd_settings = replacement.etcd_config
        etcd_settings.etcd_host = "new-etcd-host"
        etcd_settings.etcd_port = 2380
        etcd_settings.etcd_timeout = 30.0
        etcd_settings.enable_etcd_persistence = True

        # Forget earlier constructor calls so only the one from update_config is counted.
        etcd_class.reset_mock()
        manager.update_config(replacement)

        assert manager.etcd_config is replacement.etcd_config
        assert manager.etcd_config.etcd_host == "new-etcd-host"
        assert manager.etcd_config.etcd_port == 2380
        assert manager.etcd_config.etcd_timeout == 30.0
        etcd_class.assert_called_once_with(
            etcd_config=replacement.etcd_config, tls_config=replacement.etcd_tls_config
        )
def test_persist_and_restore_instance_data_success():
    """Persist one instance, then restore it from a state built out of its dump."""
    manager = create_instance_manager_with_config(enable_etcd=True)
    original = create_test_instance(201, "test_persist_instance", ["192.168.1.1"])
    manager.add_instance(original)

    with patch.object(manager.etcd_client, 'persist_data', return_value=True) as persist_spy, \
            patch.object(manager.etcd_client, 'restore_data') as restore_stub:
        persisted = manager.persist_data()
        assert persisted
        persist_spy.assert_called_once()
        positional_args = persist_spy.call_args[0]
        assert "/controller/instance_manager" in positional_args[0]

        # Feed the instance's own dump back through the stubbed restore call.
        saved_state = PersistentState(
            data={"201": original.model_dump()},
            version=1,
            timestamp=time.time(),
            checksum="",
        )
        saved_state.checksum = saved_state.calculate_checksum()
        restore_stub.return_value = {"state": saved_state}

        with patch('motor.controller.core.instance_manager.EtcdClient'):
            new_manager = create_instance_manager_with_config(enable_etcd=True)
            restored = new_manager.restore_data()

        assert restored
        assert 201 in new_manager.instances
        recovered = new_manager.instances[201]
        assert recovered.job_name == original.job_name
        assert recovered.id == original.id
def test_persist_data_with_checksum_validation():
    """The payload handed to etcd carries a checksum that validates against its data."""
    manager = create_instance_manager_with_config(enable_etcd=True)
    manager.add_instance(create_test_instance(202, "test_checksum", ["192.168.1.2"]))

    with patch.object(manager.etcd_client, 'persist_data', return_value=True) as persist_spy:
        assert manager.persist_data()

        # The second positional argument of the etcd call is the payload.
        payload = persist_spy.call_args[0][1]
        assert "state" in payload

        state_fields = payload["state"]
        assert "checksum" in state_fields
        assert len(state_fields["checksum"]) > 0
        assert "202" in state_fields["data"]

        rebuilt = PersistentState(**state_fields)
        assert rebuilt.is_valid()
def test_restore_data_with_invalid_checksum():
    """Restoring a state with a wrong checksum fails and adds no instance."""
    manager = create_instance_manager_with_config(enable_etcd=True)

    stored_instance = {
        "id": 203,
        "job_name": "test_invalid",
        "model_name": "test_model",
        "role": "prefill",
        "endpoints": {},
        "status": "initial",
    }
    tampered_state = PersistentState(
        data={"203": stored_instance},
        version=1,
        timestamp=time.time(),
        checksum="invalid_checksum",
    )
    etcd_payload = {"state": tampered_state}

    with patch.object(manager.etcd_client, 'restore_data', return_value=etcd_payload):
        outcome = manager.restore_data()
        assert not outcome
        assert 203 not in manager.instances
def test_persistence_disabled_in_config():
    """``persist_data`` still returns a truthy result when etcd persistence is disabled."""
    manager = create_instance_manager_with_config(enable_etcd=False)
    manager.add_instance(create_test_instance(204, "test_disabled", ["192.168.1.3"]))

    with patch.object(manager.etcd_client, 'persist_data', return_value=True):
        outcome = manager.persist_data()
        assert outcome
def test_persist_empty_instances():
    """With no instances registered, the persisted state carries an empty ``data`` mapping."""
    manager = create_instance_manager_with_config(enable_etcd=True)

    with patch.object(manager.etcd_client, 'persist_data', return_value=True) as persist_spy:
        assert manager.persist_data()

        # The second positional argument of the etcd call is the payload.
        payload = persist_spy.call_args[0][1]
        assert "state" in payload

        state_fields = payload["state"]
        assert "data" in state_fields
        assert len(state_fields["data"]) == 0
def test_restore_no_instance_data_available():
    """Restoring when etcd returns ``None`` succeeds and leaves the manager empty."""
    manager = create_instance_manager_with_config(enable_etcd=True)

    with patch.object(manager.etcd_client, 'restore_data', return_value=None):
        outcome = manager.restore_data()
        assert outcome
        assert len(manager.instances) == 0
def test_persistent_state_is_valid_method():
    """``is_valid`` is true only when the stored checksum matches the computed one."""
    payload = {
        "id": 205,
        "job_name": "test_valid",
        "model_name": "test_model",
        "role": "prefill",
        "endpoints": {},
        "status": "active",
    }

    sealed = PersistentState(data=payload, version=1, timestamp=time.time(), checksum="")
    sealed.checksum = sealed.calculate_checksum()
    assert sealed.is_valid()

    mismatched = PersistentState(
        data=payload,
        version=1,
        timestamp=time.time(),
        checksum="wrong_checksum",
    )
    assert not mismatched.is_valid()
def test_restore_data_with_type_conversion():
    """String-typed fields in the stored data are converted when the instance is restored."""
    # ``id`` and the workload numbers are deliberately given as strings.
    stringly_typed = {
        "id": "206",
        "job_name": "test_type_conversion",
        "model_name": "test_model",
        "role": "prefill",
        "status": "active",
        "endpoints": {},
        "parallel_config": None,
        "node_managers": [],
        "gathered_workload": {"memory_mb": "1024", "cpu_cores": "2"},
    }
    stored_state = PersistentState(data={"206": stringly_typed}, version=1, timestamp=time.time(), checksum="")
    stored_state.checksum = stored_state.calculate_checksum()

    etcd_client = MagicMock()
    etcd_client.restore_data.return_value = {"state": stored_state}

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = etcd_client
        manager = create_instance_manager_with_config(enable_etcd=True)

        outcome = manager.restore_data()

        assert outcome
        assert 206 in manager.instances
        restored = manager.instances[206]
        assert restored.id == 206
        assert restored.job_name == "test_type_conversion"
        assert restored.status == InsStatus.ACTIVE
def test_restore_data_with_invalid_enum_value():
    """An entry with an unknown status is not restored, yet ``restore_data`` still succeeds."""
    bad_status_entry = {
        "id": "207",
        "job_name": "test_invalid_enum",
        "model_name": "test_model",
        "role": "prefill",
        "status": "INVALID_STATUS",
        "endpoints": {},
        "parallel_config": None,
        "node_managers": [],
        "gathered_workload": {"memory_mb": "1024", "cpu_cores": "2"},
    }
    # The checksum is valid; only the entry's content is unusable.
    stored_state = PersistentState(data={"207": bad_status_entry}, version=1, timestamp=time.time(), checksum="")
    stored_state.checksum = stored_state.calculate_checksum()

    etcd_client = MagicMock()
    etcd_client.restore_data.return_value = {"state": stored_state}

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = etcd_client
        manager = create_instance_manager_with_config(enable_etcd=True)

        outcome = manager.restore_data()

        assert outcome
        assert 207 not in manager.instances
def test_restore_data_with_malformed_numeric_data():
    """An entry with a non-numeric id is not restored, yet ``restore_data`` still succeeds."""
    bad_id_entry = {
        "id": "not_a_number",
        "job_name": "test_malformed_number",
        "model_name": "test_model",
        "role": "prefill",
        "status": "active",
        "endpoints": {},
        "parallel_config": None,
        "node_managers": [],
        "gathered_workload": {"memory_mb": "1024", "cpu_cores": "2"},
    }
    # The checksum is valid; only the entry's content is unusable.
    stored_state = PersistentState(data={"invalid": bad_id_entry}, version=1, timestamp=time.time(), checksum="")
    stored_state.checksum = stored_state.calculate_checksum()

    etcd_client = MagicMock()
    etcd_client.restore_data.return_value = {"state": stored_state}

    with patch('motor.controller.core.instance_manager.EtcdClient') as etcd_class:
        etcd_class.return_value = etcd_client
        manager = create_instance_manager_with_config(enable_etcd=True)

        outcome = manager.restore_data()

        assert outcome
        assert len(manager.instances) == 0