import pytest
import queue
from unittest.mock import Mock, patch, MagicMock
from motor.config.controller import ControllerConfig
from motor.controller.core.event_pusher import EventPusher, Event
from motor.common.resources.instance import Instance, ReadOnlyInstance
from motor.controller.core.observer import ObserverEvent
from motor.common.resources.http_msg_spec import EventType
@pytest.fixture
def event_pusher():
    """Provide an EventPusher built from a default ControllerConfig.

    ``threading.Thread`` is replaced by a mock factory only while the object
    is being constructed; the patch is released as soon as the fixture
    returns, so it is no longer active inside the test body.
    """
    with patch('threading.Thread', return_value=MagicMock()):
        return EventPusher(ControllerConfig())
@pytest.fixture
def mock_instance():
    """Provide a Mock restricted to the Instance spec with job_name "test_job"."""
    stub = Mock(spec=Instance)
    stub.configure_mock(job_name="test_job")
    return stub
@pytest.fixture
def mock_http_client():
    """Patch CoordinatorApiClient.send_instance_refresh for the duration of a test.

    Yields the replacement mock, which is preconfigured to return True.
    """
    target = 'motor.controller.core.event_pusher.CoordinatorApiClient.send_instance_refresh'
    with patch(target, return_value=True) as send_refresh:
        yield send_refresh
def test_init(event_pusher):
    """A freshly constructed EventPusher has empty state and no worker threads."""
    assert isinstance(event_pusher.event_queue, queue.Queue)
    assert event_pusher.instances == {}
    assert not event_pusher.is_coordinator_reset
    # Both thread handles stay unset until start() is called.
    for thread_attr in ("event_consumer_thread", "heartbeat_detector_thread"):
        assert getattr(event_pusher, thread_attr) is None
def test_start():
    """start() populates both thread handles and starts the mocked thread twice."""
    with patch('threading.Thread') as thread_cls:
        fake_thread = MagicMock()
        thread_cls.return_value = fake_thread
        pusher = EventPusher(ControllerConfig())

        # No thread handle exists before start().
        assert pusher.event_consumer_thread is None
        assert pusher.heartbeat_detector_thread is None

        pusher.start()

        consumer = pusher.event_consumer_thread
        detector = pusher.heartbeat_detector_thread
        assert consumer is not None
        assert detector is not None
        assert consumer.daemon
        assert detector.daemon

        # Every Thread(...) call returns the same mock, so its start() count
        # covers both worker threads.
        fake_thread.start.assert_called()
        assert fake_thread.start.call_count == 2
def test_event_consumer_add_event(event_pusher, mock_http_client):
    """Consuming a queued ADD event results in exactly one send_instance_refresh call."""
    source = Instance(job_name="test_job", model_name="test_model", id=1, role="prefill")
    tracked = ReadOnlyInstance(source)
    event_pusher.instances["test_job"] = tracked

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.ADD, instance=tracked.to_instance()))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch.object(pending, 'get', side_effect=non_blocking_get):
        try:
            event_pusher._event_consumer()
        except StopIteration:
            pass

    mock_http_client.assert_called_once()
def test_event_consumer_del_event(event_pusher, mock_http_client):
    """Consuming a queued DEL event results in exactly one send_instance_refresh call."""
    source = Instance(job_name="test_job", model_name="test_model", id=1, role="prefill")
    tracked = ReadOnlyInstance(source)
    event_pusher.instances["test_job"] = tracked

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.DEL, instance=tracked.to_instance()))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch.object(pending, 'get', side_effect=non_blocking_get):
        try:
            event_pusher._event_consumer()
        except StopIteration:
            pass

    mock_http_client.assert_called_once()
def test_event_consumer_set_event(event_pusher, mock_http_client):
    """A SET event is pushed once when both prefill and decode instances are registered."""
    # Register two prefill (ids 0, 1) and two decode (ids 10, 11) instances,
    # interleaved prefill-then-decode for each index.
    for idx in range(2):
        for role, id_offset in (("prefill", 0), ("decode", 10)):
            name = f"test_{role}_job{idx}"
            inst = Instance(job_name=name, model_name="test_model", id=idx + id_offset, role=role)
            event_pusher.instances[name] = ReadOnlyInstance(inst)

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.SET, instance=None))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch.object(pending, 'get', side_effect=non_blocking_get):
        try:
            event_pusher._event_consumer()
        except StopIteration:
            pass

    mock_http_client.assert_called_once()
def test_event_consumer_set_event_skip_missing_prefill(event_pusher, mock_http_client):
    """A SET event is not pushed when only decode instances are registered.

    Also verifies that the skip is reported through a single logger.debug call.
    """
    for idx in range(2):
        name = f"test_decode_job{idx}"
        decode_inst = Instance(job_name=name, model_name="test_model", id=idx, role="decode")
        event_pusher.instances[name] = ReadOnlyInstance(decode_inst)

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.SET, instance=None))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch('motor.controller.core.event_pusher.logger') as mock_logger:
        with patch.object(pending, 'get', side_effect=non_blocking_get):
            try:
                event_pusher._event_consumer()
            except StopIteration:
                pass

    mock_http_client.assert_not_called()
    mock_logger.debug.assert_called_once_with(
        "SET event skipped: requires at least one prefill instance, current instances: prefill=%s",
        False,
    )
def test_event_consumer_set_event_missing_decode(event_pusher, mock_http_client):
    """A SET event is still pushed when only prefill instances are registered.

    The assertion requires send_instance_refresh to have been called, i.e. a
    missing decode instance does not suppress the push.
    """
    for idx in range(2):
        name = f"test_prefill_job{idx}"
        prefill_inst = Instance(job_name=name, model_name="test_model", id=idx, role="prefill")
        event_pusher.instances[name] = ReadOnlyInstance(prefill_inst)

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.SET, instance=None))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch.object(pending, 'get', side_effect=non_blocking_get):
        try:
            event_pusher._event_consumer()
        except StopIteration:
            pass

    mock_http_client.assert_called()
def test_event_consumer_exception_handling(event_pusher, mock_http_client):
    """The consumer calls send_instance_refresh once even when it reports failure.

    The failure is simulated by making the mocked sender return False; no
    exception is raised by the mock itself.
    """
    mock_http_client.return_value = False

    source = Instance(job_name="test_job", model_name="test_model", id=1, role="prefill")
    tracked = ReadOnlyInstance(source)
    event_pusher.instances["test_job"] = tracked

    pending = event_pusher.event_queue
    pending.put(Event(event_type=EventType.ADD, instance=tracked.to_instance()))
    pending.put(None)

    # Capture the real get() before patching it. The replacement never blocks
    # and raises StopIteration on an empty queue, which the test swallows below.
    real_get = pending.get

    def non_blocking_get(timeout=None):
        try:
            return real_get(block=False)
        except queue.Empty:
            raise StopIteration

    with patch.object(pending, 'get', side_effect=non_blocking_get):
        try:
            event_pusher._event_consumer()
        except StopIteration:
            pass

    mock_http_client.assert_called_once()
def test_heartbeat_detector_normal(event_pusher):
    """With query_status reporting ready, the detector clears the reset flag and queues a SET event."""
    status_target = 'motor.controller.core.event_pusher.CoordinatorApiClient.query_status'
    with patch(status_target, return_value={"ready": True}):
        event_pusher.is_coordinator_reset = True
        wait_calls = 0

        def stop_on_second_wait(timeout=None):
            # The first wait returns normally; the second aborts the detector
            # by raising StopIteration, which the test swallows below.
            nonlocal wait_calls
            wait_calls += 1
            if wait_calls >= 2:
                raise StopIteration

        with patch.object(event_pusher.work_condition, 'wait', side_effect=stop_on_second_wait):
            try:
                event_pusher._coordinator_heartbeat_detector()
            except StopIteration:
                pass

        assert not event_pusher.is_coordinator_reset
        assert not event_pusher.event_queue.empty()
        queued = event_pusher.event_queue.get()
        assert queued.event_type == EventType.SET
        assert queued.instance is None
def test_heartbeat_detector_failure(event_pusher):
    """A heartbeat that succeeds once and then raises leads to a "heartbeat lost" warning."""
    status_calls = 0

    def flaky_query_status(params: dict = None):
        # First call succeeds and marks the first heartbeat as successful;
        # every later call raises.
        nonlocal status_calls
        status_calls += 1
        if status_calls != 1:
            raise RuntimeError("Connection failed")
        event_pusher.is_first_heartbeat_success = True
        return {"ready": True}

    status_target = 'motor.controller.core.event_pusher.CoordinatorApiClient.query_status'
    with patch(status_target, side_effect=flaky_query_status):
        wait_calls = 0

        def stop_on_fifth_wait(timeout=None):
            # The fifth wait aborts the detector by raising StopIteration,
            # which the test swallows below.
            nonlocal wait_calls
            wait_calls += 1
            if wait_calls >= 5:
                raise StopIteration

        with patch('motor.controller.core.event_pusher.logger') as mock_logger:
            with patch.object(event_pusher.work_condition, 'wait', side_effect=stop_on_fifth_wait):
                try:
                    event_pusher._coordinator_heartbeat_detector()
                except StopIteration:
                    pass

            assert mock_logger.warning.call_count >= 1
            expected_text = "Coordinator heartbeat lost. Possible restart detected"
            assert any(
                expected_text in str(logged)
                for logged in mock_logger.warning.call_args_list
            )
def test_update_add_instance(event_pusher):
    """INSTANCE_READY stores the instance under its job name and queues an ADD event."""
    view = ReadOnlyInstance(
        Instance(job_name="test_job", model_name="test_model", id=1, role="prefill")
    )

    event_pusher.update(view, ObserverEvent.INSTANCE_READY)

    assert view.job_name in event_pusher.instances
    assert event_pusher.instances[view.job_name] == view

    assert not event_pusher.event_queue.empty()
    queued = event_pusher.event_queue.get()
    assert queued.event_type == EventType.ADD
    assert queued.instance.job_name == view.job_name
def test_update_remove_instance(event_pusher):
    """INSTANCE_REMOVED for a registered instance queues a DEL event for that job."""
    view = ReadOnlyInstance(
        Instance(job_name="test_job", model_name="test_model", id=1, role="prefill")
    )
    event_pusher.instances[view.job_name] = view

    event_pusher.update(view, ObserverEvent.INSTANCE_REMOVED)

    assert not event_pusher.event_queue.empty()
    queued = event_pusher.event_queue.get()
    assert queued.event_type == EventType.DEL
    assert queued.instance.job_name == view.job_name
def test_update_seperated_instance(event_pusher):
    """INSTANCE_SEPERATED for a registered instance queues a DEL event for that job."""
    view = ReadOnlyInstance(
        Instance(job_name="test_job_seperated", model_name="test_model", id=1, role="prefill")
    )
    event_pusher.instances[view.job_name] = view

    event_pusher.update(view, ObserverEvent.INSTANCE_SEPERATED)

    assert not event_pusher.event_queue.empty()
    queued = event_pusher.event_queue.get()
    assert queued.event_type == EventType.DEL
    assert queued.instance.job_name == view.job_name
def test_update_seperated_instance_recovery(event_pusher):
    """After a separation, a following INSTANCE_READY queues an ADD event again."""
    view = ReadOnlyInstance(
        Instance(job_name="test_job_recovery", model_name="test_model", id=1, role="prefill")
    )
    event_pusher.instances[view.job_name] = view
    pending = event_pusher.event_queue

    event_pusher.update(view, ObserverEvent.INSTANCE_SEPERATED)

    # Discard whatever the separation queued so only the recovery event remains.
    while not pending.empty():
        pending.get()

    event_pusher.update(view, ObserverEvent.INSTANCE_READY)

    assert not pending.empty()
    queued = pending.get()
    assert queued.event_type == EventType.ADD
    assert queued.instance.job_name == view.job_name
def test_update_deep_copy_instance(event_pusher):
    """The instance carried by a queued ADD event is independent of the source instance.

    Mutating the source (scalar fields and a nested node manager) after
    update() must not be visible through the event, while the pusher's
    ``instances`` mapping keeps the very same read-only wrapper it was given.
    """
    from motor.common.resources.instance import NodeManagerInfo

    job, model = "original_job", "original_model"
    source = Instance(job_name=job, model_name=model, id=1, role="prefill")
    source.node_managers.append(NodeManagerInfo(pod_ip="192.168.1.1", host_ip="10.0.0.1", port="8080"))
    view = ReadOnlyInstance(source)

    event_pusher.update(view, ObserverEvent.INSTANCE_READY)

    assert not event_pusher.event_queue.empty()
    queued = event_pusher.event_queue.get()
    assert queued.event_type == EventType.ADD

    # The event holds a distinct object with equal content.
    assert queued.instance is not view
    assert queued.instance.job_name == job
    assert queued.instance.model_name == model
    assert queued.instance.node_managers is not source.node_managers
    assert len(queued.instance.node_managers) == len(source.node_managers)
    assert queued.instance.node_managers[0].pod_ip == source.node_managers[0].pod_ip

    # Changes made to the source afterwards must not leak into the event.
    source.job_name = "modified_job"
    source.model_name = "modified_model"
    source.node_managers[0].pod_ip = "192.168.1.2"

    assert queued.instance.job_name == job
    assert queued.instance.model_name == model
    assert queued.instance.node_managers[0].pod_ip == "192.168.1.1"
    assert event_pusher.instances[job] is view
def test_update_deep_copy_seperated_instance(event_pusher):
    """The instance carried by a DEL event from a separation is independent of the source."""
    job = "seperated_job"
    source = Instance(job_name=job, model_name="test_model", id=1, role="prefill")
    view = ReadOnlyInstance(source)
    event_pusher.instances[view.job_name] = view

    event_pusher.update(view, ObserverEvent.INSTANCE_SEPERATED)

    assert not event_pusher.event_queue.empty()
    queued = event_pusher.event_queue.get()
    assert queued.event_type == EventType.DEL
    assert queued.instance is not view
    assert queued.instance.job_name == job

    # Renaming the source afterwards must not affect the queued copy.
    source.job_name = "modified_seperated_job"
    assert queued.instance.job_name == job
def test_update_seperated_instance_initial_stage_abnormal(event_pusher):
    """INSTANCE_SEPERATED for an instance that was never registered is a no-op.

    No event is queued and the instance is not added to ``instances``.
    """
    view = ReadOnlyInstance(
        Instance(job_name="test_job_initial_abnormal", model_name="test_model", id=1, role="prefill")
    )

    event_pusher.update(view, ObserverEvent.INSTANCE_SEPERATED)

    assert event_pusher.event_queue.empty()
    assert view.job_name not in event_pusher.instances
def test_update_config():
    """update_config() applies the new coordinator heartbeat interval."""
    with patch('threading.Thread') as thread_cls:
        thread_cls.return_value = MagicMock()
        pusher = EventPusher(ControllerConfig())
        interval_before = pusher.coordinator_heartbeat_interval

        replacement = ControllerConfig()
        replacement.event_config.coordinator_heartbeat_interval = 20.0
        pusher.update_config(replacement)

        interval_after = pusher.coordinator_heartbeat_interval
        assert interval_after == 20.0
        # Guards against the default already being 20.0, which would make the
        # previous assertion pass without any update taking place.
        assert interval_after != interval_before