import time
import hashlib
import pytest
from unittest.mock import MagicMock, patch
from motor.common.resources import Instance, InsStatus, ParallelConfig, Endpoint, ReadOnlyInstance
from motor.common.resources.http_msg_spec import RegisterMsg, ReregisterMsg, Ranktable, ServerInfo, DeviceInfo
from motor.controller.core.instance_assembler import (
InstanceAssembler,
AssembleInstanceMetadata,
RegisterStatus,
)
from motor.common.etcd.persistent_state import PersistentState
from motor.common.utils.singleton import ThreadSafeSingleton
from motor.controller.core import InstanceManager
from motor.config.controller import ControllerConfig
def build_pod_ranktable(
    pod_ip: str,
    pod_device_num: int,
    rank_offset: int = 0,
    is_supperpod: bool = True,
) -> Ranktable:
    """
    Build a pod-level ranktable for use in test cases.

    A pod corresponds to exactly one server, so ``server_list`` always holds a
    single entry whose ``device`` list has ``pod_device_num`` items. Device ids
    count up from 0; rank ids count up from ``rank_offset``. Every device gets
    ``super_device_id="0"`` when ``is_supperpod`` is true, otherwise ``None``.
    """
    super_device_id = "0" if is_supperpod else None

    devices = []
    for index in range(pod_device_num):
        devices.append(
            DeviceInfo(
                device_ip=pod_ip,
                device_id=str(index),
                rank_id=str(rank_offset + index),
                super_device_id=super_device_id,
            )
        )

    server = ServerInfo(server_id=pod_ip, container_ip=pod_ip, device=devices)
    return Ranktable(
        version="1.2",
        status="completed",
        server_count="1",
        server_list=[server],
    )
@pytest.fixture
def test_config():
    """Shared test parameters: a prefill role with dp=4, tp=2 spread over two pod IPs."""
    dp, tp = 4, 2
    config = {
        'dp': dp,
        'tp': tp,
        'role': "prefill",
        'pod_ip1': "127.0.0.1",
        'pod_ip2': "127.0.0.2",
    }
    # local_world_size equals tp and the world size is the dp * tp product.
    config['parallel_config'] = ParallelConfig(
        dp_size=dp,
        tp_size=tp,
        local_world_size=tp,
        world_size=dp * tp,
    )
    return config
def _cleanup_singletons():
    """Stop and unregister the singletons created by these tests so each test starts clean."""
    for singleton_cls in (InstanceAssembler, InstanceManager):
        if singleton_cls not in ThreadSafeSingleton._instances:
            continue

        existing = ThreadSafeSingleton._instances[singleton_cls]
        try:
            if hasattr(existing, 'stop'):
                existing.stop()
        except Exception:
            # Teardown is deliberately best-effort: a failing stop() must not
            # prevent the registry entry from being removed below.
            pass

        del ThreadSafeSingleton._instances[singleton_cls]
@pytest.fixture(autouse=True)
def cleanup_singletons():
    """Autouse fixture that resets the singleton registry around every test."""
    # Before the test: drop anything a previous test left behind.
    _cleanup_singletons()

    yield

    # After the test: drop whatever this test created.
    _cleanup_singletons()
@pytest.fixture
def mock_config():
    """Controller config with ETCD persistence off and short assembler timings."""
    config = ControllerConfig()
    config.etcd_config.enable_etcd_persistence = False

    instance_cfg = config.instance_config
    instance_cfg.instance_assemble_timeout = 1.0
    instance_cfg.instance_assembler_check_interval = 0.1
    instance_cfg.instance_assembler_cmd_send_interval = 0.1
    instance_cfg.send_cmd_retry_times = 3
    return config
@pytest.fixture
def instance_assembler(mock_config):
    """Yield an InstanceAssembler built with threading.Thread and EtcdClient replaced by mocks."""
    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread') as thread_cls, patch(etcd_target) as etcd_cls:
        # Both patched classes hand out plain MagicMock instances, so no real
        # thread is started and no real ETCD client is created.
        thread_cls.return_value = MagicMock()
        etcd_cls.return_value = MagicMock()
        yield InstanceAssembler(mock_config)
def create_register_msg(job_name: str, pod_ip: str, config: dict, **kwargs) -> RegisterMsg:
    """Build a RegisterMsg from test defaults; any keyword argument overrides the matching default."""
    fields = dict(
        model_name="test_model",
        role=config['role'],
        business_port=["8080", "8084"],
        mgmt_port=["9090", "9094"],
        nm_port="8088",
        parallel_config=config['parallel_config'],
        enable_multi_endpoints=True,
        device_num=2 * config['tp'],
    )
    # Caller-supplied values win over the defaults above.
    fields.update(kwargs)
    return RegisterMsg(job_name=job_name, pod_ip=pod_ip, **fields)
def create_reregister_msg(job_name: str, pod_ip: str, instance_id: int, config: dict, endpoints: list) -> ReregisterMsg:
    """Build a ReregisterMsg from test defaults.

    ``endpoints`` may also be a dict; in that case only its values are sent.
    """
    endpoints_list = list(endpoints.values()) if isinstance(endpoints, dict) else endpoints
    return ReregisterMsg(
        job_name=job_name,
        model_name="test_model",
        instance_id=instance_id,
        role=config['role'],
        pod_ip=pod_ip,
        nm_port="8088",
        parallel_config=config['parallel_config'],
        endpoints=endpoints_list,
        enable_multi_endpoints=True,
    )
def register_instance_with_pods(assembler: InstanceAssembler, job_name: str, config: dict, pod_count: int = 2) -> bool:
    """Register ``pod_count`` pods under ``job_name``, run one assembly pass, and report whether it assembled."""
    devices_per_pod = 2 * config['tp']

    for index in range(pod_count):
        pod_ip = f"127.0.0.{index + 1}"
        # Each pod's ranks continue where the previous pod's ranks ended.
        ranktable = build_pod_ranktable(
            pod_ip=pod_ip,
            pod_device_num=devices_per_pod,
            rank_offset=index * devices_per_pod,
        )
        msg = create_register_msg(job_name, pod_ip, config, ranktable=ranktable)
        result = assembler.register(msg)
        assert result == 0

    if job_name not in assembler.instances:
        return False

    metadata = assembler.instances[job_name]
    # The endpoint filter is stubbed out for the assembly pass.
    with patch.object(assembler, '_filter_abnormal_endpoints'):
        assembler._assemble_instance(metadata)
    return metadata.register_status == RegisterStatus.ASSEMBLED
def create_assembled_instance(assembler: InstanceAssembler, job_name: str, config: dict) -> AssembleInstanceMetadata:
    """Register and assemble a full instance, failing the test if assembly does not complete."""
    assembled = register_instance_with_pods(assembler, job_name, config)
    assert assembled, f"Failed to assemble instance {job_name}"

    metadata = assembler.instances[job_name]
    return metadata
def test_initialization(mock_config):
    """A freshly constructed InstanceAssembler starts with empty state."""
    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread'), patch(etcd_target):
        assembler = InstanceAssembler(mock_config)

    # The assembler keeps the very same etcd config object it was given.
    assert assembler.etcd_config is mock_config.etcd_config
    assert assembler.ins_id_cnt == 1
    assert len(assembler.instances) == 0
    assert not assembler.stop_event.is_set()
    assert assembler._data_version == 0
def test_singleton_behavior(mock_config):
    """Constructing the assembler twice returns the same object and skips re-initialization."""
    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread'), patch(etcd_target):
        first = InstanceAssembler(mock_config)
        timeout_before = first.instance_assemble_timeout

        other_config = ControllerConfig()
        other_config.instance_config.instance_assemble_timeout = 999
        second = InstanceAssembler(other_config)

    assert first is second
    # The second config must not have overwritten the original timeout.
    assert first.instance_assemble_timeout == timeout_before
def test_init_with_none_config():
    """Test initialization with None config uses default"""
    with patch('threading.Thread'), patch('motor.controller.core.instance_assembler.EtcdClient'):
        assembler = InstanceAssembler(config=None)

    # Check existence first: reading the attribute would raise AttributeError
    # before a later hasattr() assertion could ever report the failure.
    assert hasattr(assembler, 'instance_assemble_timeout')
    assert assembler.instance_assemble_timeout is not None
def test_register_new_instance(instance_assembler, test_config):
    """Registering the first pod of a job creates a new, not-yet-assembled instance entry."""
    job_name = "test_job"
    register_msg = create_register_msg(job_name, test_config['pod_ip1'], test_config)

    status_code = instance_assembler.register(register_msg)

    assert status_code == 0
    assert job_name in instance_assembler.instances

    metadata = instance_assembler.instances[job_name]
    instance = metadata.instance
    assert metadata.register_status == RegisterStatus.NOT_REGISTERED
    assert instance.job_name == job_name
    # The first instance takes id 1 and the counter moves on to 2.
    assert instance.id == 1
    assert instance_assembler.ins_id_cnt == 2
    assert len(instance.endpoints) == 1
    assert len(instance.node_managers) == 1
def test_register_existing_instance(instance_assembler, test_config):
    """A second pod registering under the same job joins the existing instance."""
    job_name = "test_job"
    devices_per_pod = 2 * test_config['tp']
    second_pod_ip = test_config['pod_ip2']

    first_msg = create_register_msg(job_name, test_config['pod_ip1'], test_config)
    assert instance_assembler.register(first_msg) == 0
    assert len(instance_assembler.instances) == 1

    # The second pod's ranks start right after the first pod's devices.
    second_ranktable = build_pod_ranktable(
        pod_ip=second_pod_ip,
        pod_device_num=devices_per_pod,
        rank_offset=devices_per_pod,
    )
    second_msg = create_register_msg(job_name, second_pod_ip, test_config, ranktable=second_ranktable)
    assert instance_assembler.register(second_msg) == 0

    # Still one instance, now holding endpoints for both pods.
    assert len(instance_assembler.instances) == 1
    assert len(instance_assembler.instances[job_name].instance.endpoints) == 2
def test_register_already_assembled_instance(instance_assembler, test_config):
    """Registering against a job that already has an active instance is rejected with -1."""
    job_name = "test_job"
    metadata = create_assembled_instance(instance_assembler, job_name, test_config)
    assert job_name in instance_assembler.instances
    assert metadata.register_status == RegisterStatus.ASSEMBLED

    def abort_wait(*args, **kwargs):
        # Raised from the patched wait() to break out of the sender loop.
        raise RuntimeError("Stop iteration")

    send_patch = patch(
        'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command',
        return_value=True,
    )
    wait_patch = patch.object(instance_assembler.work_condition, 'wait', side_effect=abort_wait)
    with send_patch, wait_patch:
        try:
            instance_assembler._start_commmand_sender()
        except RuntimeError as err:
            if "Stop iteration" not in str(err):
                raise

    # After a successful start command the job is no longer tracked by the assembler.
    assert job_name not in instance_assembler.instances

    with patch.object(InstanceManager(), 'has_active_instance_by_job_name', return_value=True):
        late_msg = create_register_msg(job_name, "127.0.0.3", test_config)
        result = instance_assembler.register(late_msg)
    assert result == -1
def test_reregister_new_instance(instance_assembler, test_config):
    """Reregistering an unknown job creates an entry that keeps the reported instance id."""
    job_name = "test_reregister"
    pod_ip = test_config['pod_ip1']

    register_msg = create_register_msg(job_name, pod_ip, test_config)
    endpoints = instance_assembler._build_single_endpoint(register_msg, 0)
    reregister_msg = create_reregister_msg(
        job_name,
        pod_ip,
        instance_id=5,
        config=test_config,
        endpoints=endpoints,
    )

    assert instance_assembler.reregister(reregister_msg) == 0
    assert job_name in instance_assembler.instances

    metadata = instance_assembler.instances[job_name]
    assert metadata.register_status == RegisterStatus.NOT_REGISTERED
    assert metadata.is_reregister is True
    # The reported id 5 is kept and the counter ends up at 6.
    assert metadata.instance.id == 5
    assert instance_assembler.ins_id_cnt == 6
def test_reregister_already_assembled_instance(instance_assembler, test_config):
    """Reregistering against a job that already has an active instance is rejected with -1."""
    job_name = "test_reregister"
    first_ip = test_config['pod_ip1']
    second_ip = test_config['pod_ip2']
    devices_per_pod = 2 * test_config['tp']

    # First pod reregisters.
    first_reg = create_register_msg(job_name, first_ip, test_config)
    first_endpoints = instance_assembler._build_multi_endpoints(first_reg, 0)
    first_msg = create_reregister_msg(
        job_name, first_ip, instance_id=0, config=test_config, endpoints=first_endpoints
    )
    assert instance_assembler.reregister(first_msg) == 0

    # Second pod reregisters with ranks offset past the first pod's devices.
    second_ranktable = build_pod_ranktable(
        pod_ip=second_ip,
        pod_device_num=devices_per_pod,
        rank_offset=devices_per_pod,
    )
    second_reg = create_register_msg(job_name, second_ip, test_config, ranktable=second_ranktable)
    second_endpoints = instance_assembler._build_multi_endpoints(second_reg, 2)
    second_msg = create_reregister_msg(
        job_name, second_ip, instance_id=0, config=test_config, endpoints=second_endpoints
    )
    assert instance_assembler.reregister(second_msg) == 0

    metadata = instance_assembler.instances[job_name]
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
    # A reregistered instance leaves the assembler as soon as it is assembled.
    assert job_name not in instance_assembler.instances

    with patch.object(InstanceManager(), 'has_active_instance_by_job_name', return_value=True):
        late_msg = create_reregister_msg(
            job_name, "127.0.0.3", instance_id=0, config=test_config, endpoints=first_endpoints
        )
        late_result = instance_assembler.reregister(late_msg)
    assert late_result == -1
def test_eval_register_status(instance_assembler, test_config):
    """_eval_register_status reports NOT_REGISTERED, ASSEMBLING and ASSEMBLED as appropriate."""
    # Unknown job.
    assert instance_assembler._eval_register_status("test_new") == RegisterStatus.NOT_REGISTERED

    # Job with one registered pod that has not been assembled yet.
    assembling_job = "test_assembling"
    instance_assembler.register(create_register_msg(assembling_job, test_config['pod_ip1'], test_config))
    assert instance_assembler._eval_register_status(assembling_job) == RegisterStatus.ASSEMBLING

    # Job for which the instance manager reports an active instance.
    with patch.object(InstanceManager(), 'has_active_instance_by_job_name', return_value=True):
        assembled_status = instance_assembler._eval_register_status("test_assembled")
    assert assembled_status == RegisterStatus.ASSEMBLED
def test_assembly_incomplete_instance(instance_assembler, test_config):
    """An instance with too few endpoints keeps its status and stays in the assembler."""
    job_name = "test_incomplete"
    # Only one business port is supplied for this registration.
    partial_msg = create_register_msg(job_name, test_config['pod_ip1'], test_config, business_port=["8080"])
    instance_assembler.register(partial_msg)

    metadata = instance_assembler.instances[job_name]
    status_before = metadata.register_status

    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)

    assert metadata.register_status == status_before
    assert job_name in instance_assembler.instances
def test_assembly_complete_instance_new_registration(instance_assembler, test_config):
    """A fully registered new instance becomes ASSEMBLED and is known to the instance manager."""
    job_name = "test_complete_new"

    metadata = create_assembled_instance(instance_assembler, job_name, test_config)

    assert metadata.register_status == RegisterStatus.ASSEMBLED
    # A newly registered instance remains tracked by the assembler after assembly.
    assert job_name in instance_assembler.instances
    assert InstanceManager().has_instance_by_job_name(job_name)
def test_assembly_complete_instance_reregistration(instance_assembler, test_config):
    """A fully reregistered instance is handed to the instance manager and dropped from the assembler."""
    job_name = "test_complete_reregister"
    first_ip = test_config['pod_ip1']
    second_ip = test_config['pod_ip2']
    devices_per_pod = 2 * test_config['tp']

    first_reg = create_register_msg(job_name, first_ip, test_config)
    second_reg = create_register_msg(
        job_name,
        second_ip,
        test_config,
        ranktable=build_pod_ranktable(
            pod_ip=second_ip,
            pod_device_num=devices_per_pod,
            rank_offset=devices_per_pod,
        ),
    )
    first_endpoints = instance_assembler._build_multi_endpoints(first_reg, 0)
    second_endpoints = instance_assembler._build_multi_endpoints(second_reg, 2)

    first_msg = create_reregister_msg(job_name, first_ip, 0, config=test_config, endpoints=first_endpoints)
    second_msg = create_reregister_msg(job_name, second_ip, 0, config=test_config, endpoints=second_endpoints)
    instance_assembler.reregister(first_msg)
    instance_assembler.reregister(second_msg)

    metadata = instance_assembler.instances[job_name]
    assert metadata.is_reregister is True

    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)

    assert job_name not in instance_assembler.instances
    assert InstanceManager().has_instance_by_job_name(job_name)
@patch('motor.controller.core.instance_assembler.NodeManagerApiClient.query_status')
def test_assembly_timeout(mock_query_status, instance_assembler, test_config):
    """An incomplete instance is dropped once the assemble timeout has elapsed."""
    mock_query_status.return_value = {"status": True}
    job_name = "test_timeout"

    # Use a very short timeout, register an incomplete instance, then wait past it.
    instance_assembler.instance_assemble_timeout = 0.05
    partial_msg = create_register_msg(job_name, test_config['pod_ip1'], test_config, business_port=["8080"])
    instance_assembler.register(partial_msg)
    time.sleep(0.06)

    instance_assembler._assemble_instance(instance_assembler.instances[job_name])

    assert job_name not in instance_assembler.instances
def test_send_start_command_success(instance_assembler, test_config):
    """_send_start_command returns True when every node manager accepts the command."""
    metadata = create_assembled_instance(instance_assembler, "test_start_success", test_config)

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch(send_target, return_value=True) as mock_send:
        sent = instance_assembler._send_start_command(metadata)

    assert sent is True
    # One start command per node manager.
    assert mock_send.call_count == len(metadata.instance.node_managers)
def test_send_start_command_partial_failure(instance_assembler, test_config):
    """_send_start_command returns False when only the first node manager accepts the command."""
    metadata = create_assembled_instance(instance_assembler, "test_start_partial_failure", test_config)

    calls_seen = []

    def succeed_only_first(*args, **kwargs):
        # True for the first invocation, False for every later one.
        calls_seen.append(None)
        return len(calls_seen) == 1

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch(send_target, side_effect=succeed_only_first) as mock_send:
        sent = instance_assembler._send_start_command(metadata)

    assert sent is False
    # Every node manager is still contacted despite the failure.
    assert mock_send.call_count == len(metadata.instance.node_managers)
def test_send_start_command_no_endpoints(instance_assembler, test_config):
    """Node managers without endpoints are skipped when sending the start command."""
    instance = Instance(
        job_name="test_no_endpoints",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    # Two node managers, but only 127.0.0.1 gets endpoints attached.
    for ip, port in (("127.0.0.1", "8088"), ("127.0.0.2", "8089")):
        instance.add_node_mgr(ip, port)
    register_msg = create_register_msg("test", "127.0.0.1", test_config)
    instance.add_endpoints("127.0.0.1", instance_assembler._build_single_endpoint(register_msg, 0))

    metadata = AssembleInstanceMetadata(instance=instance)

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch(send_target, return_value=True) as mock_send:
        sent = instance_assembler._send_start_command(metadata)

    assert sent is True
    assert mock_send.call_count == 1
def test_start_command_sender_success(instance_assembler, test_config):
    """The sender loop drops an instance once its start command succeeds."""
    job_name = "test_sender_success"
    create_assembled_instance(instance_assembler, job_name, test_config)

    def abort_wait(*args, **kwargs):
        # Raised from the patched wait() to break out of the sender loop.
        raise RuntimeError("Stop iteration")

    send_patch = patch(
        'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command',
        return_value=True,
    )
    wait_patch = patch.object(instance_assembler.work_condition, 'wait', side_effect=abort_wait)
    with send_patch, wait_patch:
        try:
            instance_assembler._start_commmand_sender()
        except RuntimeError as err:
            if "Stop iteration" not in str(err):
                raise

    assert job_name not in instance_assembler.instances
def test_start_command_sender_retry(instance_assembler, test_config):
    """A failed start command keeps the instance queued and counts the attempt."""
    job_name = "test_sender_retry"
    create_assembled_instance(instance_assembler, job_name, test_config)

    def abort_wait(*args, **kwargs):
        # Raised from the patched wait() to break out of the sender loop.
        raise RuntimeError("Stop iteration")

    send_patch = patch(
        'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command',
        return_value=False,
    )
    wait_patch = patch.object(instance_assembler.work_condition, 'wait', side_effect=abort_wait)
    with send_patch, wait_patch:
        try:
            instance_assembler._start_commmand_sender()
        except RuntimeError as err:
            if "Stop iteration" not in str(err):
                raise

    assert job_name in instance_assembler.instances
    assert instance_assembler.instances[job_name].start_command_send_times == 1
def test_start_command_sender_max_retries(instance_assembler, test_config):
    """The sender loop gives up on an instance once the retry limit is reached."""
    job_name = "test_sender_max_retries"
    instance_assembler.send_cmd_retry_times = 2
    create_assembled_instance(instance_assembler, job_name, test_config)

    def abort_wait(*args, **kwargs):
        # Raised from the patched wait() to break out of the sender loop.
        raise RuntimeError("Stop iteration")

    def run_sender_pass():
        # One pass of the sender loop, cut short by the patched wait().
        with patch.object(instance_assembler.work_condition, 'wait', side_effect=abort_wait):
            try:
                instance_assembler._start_commmand_sender()
            except RuntimeError as err:
                if "Stop iteration" not in str(err):
                    raise

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch(send_target, return_value=False):
        # First failed attempt: still queued, one attempt recorded.
        run_sender_pass()
        assert job_name in instance_assembler.instances
        assert instance_assembler.instances[job_name].start_command_send_times == 1

        # Second failed attempt hits the limit of 2 and the instance is removed.
        run_sender_pass()
        assert job_name not in instance_assembler.instances
def test_persist_data_disabled(mock_config):
    """persist_data reports success when ETCD persistence is switched off."""
    mock_config.etcd_config.enable_etcd_persistence = False

    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread'), patch(etcd_target) as etcd_cls:
        fake_client = MagicMock()
        fake_client.persist_data.return_value = True
        etcd_cls.return_value = fake_client

        assembler = InstanceAssembler(mock_config)
        persisted = assembler.persist_data()

    assert persisted is True
def test_persist_data_enabled(instance_assembler, test_config):
    """Test persist_data when ETCD persistence is enabled"""
    create_assembled_instance(instance_assembler, "test_job", test_config)
    # The mock_config fixture disables persistence; turn it on explicitly so the
    # test really exercises the "enabled" case named in its title.
    instance_assembler.etcd_config.enable_etcd_persistence = True
    # Forget any calls made to the mocked client during setup.
    instance_assembler.etcd_client.persist_data.reset_mock()

    instance_assembler.persist_data()

    instance_assembler.etcd_client.persist_data.assert_called_once()
    args, _ = instance_assembler.etcd_client.persist_data.call_args
    # First positional argument carries the storage path, second the payload.
    assert "/controller/instance_assembler" in args[0]
    assert "state" in args[1]
def test_restore_data_disabled(instance_assembler, test_config):
    """restore_data reports success when ETCD persistence is switched off."""
    instance_assembler.etcd_config.enable_etcd_persistence = False

    # The client yields nothing to restore.
    with patch.object(instance_assembler.etcd_client, 'restore_data', return_value=None):
        restored = instance_assembler.restore_data()

    assert restored is True
def test_restore_data_enabled(instance_assembler, test_config):
    """Test restore_data when ETCD persistence is enabled"""
    # The mock_config fixture disables persistence; turn it on explicitly so the
    # test really exercises the "enabled" case named in its title.
    instance_assembler.etcd_config.enable_etcd_persistence = True

    state = PersistentState(
        data={"ins_id_cnt": 5, "instances": {}},
        version=1,
        timestamp=time.time(),
        checksum="",
    )
    # Stamp a matching checksum onto the state.
    state.checksum = state.calculate_checksum()
    mock_persistent_states = {"state": state}

    with patch.object(instance_assembler.etcd_client, 'restore_data', return_value=mock_persistent_states):
        result = instance_assembler.restore_data()

    assert result is True
    assert instance_assembler.ins_id_cnt == 5
def test_checksum_calculation(instance_assembler, test_config):
    """The checksum of dumped instance metadata is a non-empty string that is stable across calls."""
    metadata = create_assembled_instance(instance_assembler, "test_checksum", test_config)
    state = PersistentState(
        data=metadata.model_dump(mode='json'),
        version=1,
        timestamp=time.time(),
        checksum="",
    )

    first = state.calculate_checksum()
    assert isinstance(first, str)
    assert len(first) > 0

    # Recomputing on the unchanged state gives the same value.
    second = state.calculate_checksum()
    assert first == second
def test_ins_id_cnt_checksum(instance_assembler):
    """The checksum of an ins_id_cnt payload is a non-empty string that is stable across calls."""
    instance_assembler.ins_id_cnt = 42
    payload = {"ins_id_cnt": instance_assembler.ins_id_cnt}
    state = PersistentState(data=payload, version=1, timestamp=time.time(), checksum="")

    first = state.calculate_checksum()
    assert isinstance(first, str)
    assert len(first) > 0

    # Recomputing on the unchanged state gives the same value.
    second = state.calculate_checksum()
    assert first == second
def test_persist_data_exception_handling(instance_assembler, test_config):
    """persist_data returns False when the ETCD client raises."""
    create_assembled_instance(instance_assembler, "test_persist_exception", test_config)

    etcd_failure = Exception("ETCD connection failed")
    with patch.object(instance_assembler.etcd_client, 'persist_data', side_effect=etcd_failure):
        persisted = instance_assembler.persist_data()

    assert persisted is False
def test_restore_data_exception_handling(instance_assembler):
    """restore_data returns False when the ETCD client raises."""
    etcd_failure = Exception("ETCD connection failed")
    with patch.object(instance_assembler.etcd_client, 'restore_data', side_effect=etcd_failure):
        restored = instance_assembler.restore_data()

    assert restored is False
def test_restore_data_invalid_checksum(instance_assembler):
    """A stored state whose checksum does not match is rejected and leaves the assembler untouched."""
    corrupted_state = PersistentState(
        data={"ins_id_cnt": 5, "instances": {}},
        version=1,
        timestamp=time.time(),
        checksum="invalid_checksum",
    )

    with patch.object(instance_assembler.etcd_client, 'restore_data', return_value={"state": corrupted_state}):
        restored = instance_assembler.restore_data()

    assert restored is False
    # The counter keeps its initial value instead of the stored 5.
    assert instance_assembler.ins_id_cnt == 1
def test_restore_data_reconstruction_exception(instance_assembler):
    """An instance whose metadata fails validation is skipped while the restore still succeeds."""
    validate_target = 'motor.controller.core.instance_assembler.AssembleInstanceMetadata.model_validate'
    with patch(validate_target) as mock_validate:
        mock_validate.side_effect = Exception("Metadata validation failed")

        parallel_config = {
            "dp_size": 1,
            "pcp_size": 1,
            "tp_size": 1,
            "ep_size": 1,
            "pp_size": 1,
            "world_size": 1,
        }
        instance_data = {
            "job_name": "test_instance",
            "model_name": "test_model",
            "id": 0,
            "role": "prefill",
            "parallel_config": parallel_config,
            "endpoints": {},
            "node_managers": [],
        }
        metadata_data = {
            "instance": instance_data,
            "register_status": "NOT_REGISTERED",
            "start_command_send_times": 0,
            "register_timestamp": time.time(),
            "is_reregister": False,
        }

        state = PersistentState(
            data={"ins_id_cnt": 1, "instances": {"test_instance": metadata_data}},
            version=1,
            timestamp=time.time(),
            checksum="",
        )
        # The checksum is valid, so only the metadata validation step fails.
        state.checksum = state.calculate_checksum()

        with patch.object(instance_assembler.etcd_client, 'restore_data', return_value={"state": state}):
            restored = instance_assembler.restore_data()

        assert restored is True
        assert len(instance_assembler.instances) == 0
def test_checksum_calculation_exception_handling(instance_assembler, test_config):
    """calculate_checksum returns an empty string when hashing raises."""
    metadata = create_assembled_instance(instance_assembler, "test_checksum_exception", test_config)
    state = PersistentState(
        data=metadata.model_dump(mode='json'),
        version=1,
        timestamp=time.time(),
        checksum="",
    )

    hash_failure = Exception("Hash calculation failed")
    with patch.object(hashlib, 'sha256', side_effect=hash_failure):
        checksum = state.calculate_checksum()

    assert checksum == ""
def test_ins_id_cnt_checksum_exception_handling(instance_assembler):
    """calculate_checksum on an ins_id_cnt payload returns an empty string when hashing raises."""
    instance_assembler.ins_id_cnt = 42
    payload = {"ins_id_cnt": instance_assembler.ins_id_cnt}
    state = PersistentState(data=payload, version=1, timestamp=time.time(), checksum="")

    hash_failure = Exception("Hash calculation failed")
    with patch.object(hashlib, 'sha256', side_effect=hash_failure):
        checksum = state.calculate_checksum()

    assert checksum == ""
def test_persistent_state_is_valid_method():
    """is_valid accepts a state with its own checksum and rejects one with a wrong checksum."""
    payload = {"test": "data"}

    # A state stamped with its own computed checksum is valid.
    good_state = PersistentState(data=payload, version=1, timestamp=time.time(), checksum="")
    good_state.checksum = good_state.calculate_checksum()
    assert good_state.is_valid() is True

    # A state carrying an arbitrary checksum is not.
    bad_state = PersistentState(
        data={"test": "data"},
        version=1,
        timestamp=time.time(),
        checksum="wrong_checksum",
    )
    assert bad_state.is_valid() is False
def test_restore_data_with_type_conversion():
    """Restoring metadata whose scalar fields are strings yields properly typed values."""
    job_name = "test_type_conversion"
    # Numbers and booleans are given as strings here.
    instance_data = {
        "job_name": job_name,
        "model_name": "test_model",
        "id": "208",
        "role": "prefill",
        "parallel_config": None,
        "endpoints": {},
        "node_managers": [],
    }
    etcd_string_metadata = {
        "instance": instance_data,
        "register_status": "ASSEMBLED",
        "start_command_send_times": "0",
        "register_timestamp": str(time.time()),
        "is_reregister": "False",
    }
    stored_state = PersistentState(
        data={"ins_id_cnt": 1, "instances": {job_name: etcd_string_metadata}},
        version=1,
        timestamp=time.time(),
        checksum="",
    )
    stored_state.checksum = stored_state.calculate_checksum()

    with patch('motor.controller.core.instance_assembler.EtcdClient') as etcd_cls:
        fake_client = MagicMock()
        fake_client.restore_data.return_value = {"state": stored_state}
        etcd_cls.return_value = fake_client

        config = ControllerConfig()
        config.etcd_config.enable_etcd_persistence = True
        with patch('threading.Thread'):
            assembler = InstanceAssembler(config)
        restored = assembler.restore_data()

    assert restored is True
    assert job_name in assembler.instances

    metadata = assembler.instances[job_name]
    assert metadata.instance.id == 208
    assert metadata.register_status == RegisterStatus.ASSEMBLED
    assert metadata.start_command_send_times == 0
    assert metadata.is_reregister is False
def test_restore_data_with_invalid_enum_value():
    """Metadata with an unknown register_status is skipped while the restore still succeeds."""
    job_name = "test_invalid_enum"
    instance_data = {
        "job_name": job_name,
        "model_name": "test_model",
        "id": "209",
        "role": "prefill",
        "parallel_config": None,
        "endpoints": {},
        "node_managers": [],
    }
    corrupted_metadata = {
        "instance": instance_data,
        # "999" is the invalid status value under test.
        "register_status": "999",
        "start_command_send_times": "0",
        "register_timestamp": str(time.time()),
        "is_reregister": "False",
    }
    stored_state = PersistentState(
        data={"ins_id_cnt": 1, "instances": {job_name: corrupted_metadata}},
        version=1,
        timestamp=time.time(),
        checksum="",
    )
    stored_state.checksum = stored_state.calculate_checksum()

    with patch('motor.controller.core.instance_assembler.EtcdClient') as etcd_cls:
        fake_client = MagicMock()
        fake_client.restore_data.return_value = {"state": stored_state}
        etcd_cls.return_value = fake_client

        config = ControllerConfig()
        config.etcd_config.enable_etcd_persistence = True
        with patch('threading.Thread'):
            assembler = InstanceAssembler(config)
        restored = assembler.restore_data()

    assert restored is True
    assert job_name not in assembler.instances
def test_start_method(mock_config):
    """start() creates and starts two threads."""
    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread') as thread_cls, patch(etcd_target):
        assembler = InstanceAssembler(mock_config)
        assembler.start()

        # Two Thread objects were constructed and each was started.
        assert thread_cls.call_count == 2
        assert thread_cls.return_value.start.call_count == 2
def test_stop_method(mock_config):
    """stop() sets the stop event and joins both threads."""
    etcd_target = 'motor.controller.core.instance_assembler.EtcdClient'
    with patch('threading.Thread') as thread_cls, patch(etcd_target):
        # Two distinct thread mocks, both reporting themselves as alive.
        first_thread = MagicMock()
        second_thread = MagicMock()
        for fake_thread in (first_thread, second_thread):
            fake_thread.is_alive.return_value = True
        thread_cls.side_effect = [first_thread, second_thread]

        assembler = InstanceAssembler(mock_config)
        assembler.start()
        assembler.stop()

        assert assembler.stop_event.is_set()
        first_thread.join.assert_called_once()
        second_thread.join.assert_called_once()
def test_instances_assembler_loop_stop_event(instance_assembler, test_config):
    """Run _instances_assembler_loop with the stop event already set."""
    instance_assembler.stop_event.set()

    def abort_wait(*args, **kwargs):
        # Safety net: if the loop reaches wait(), break out of it.
        raise RuntimeError("Stop iteration")

    with patch.object(instance_assembler.work_condition, 'wait', side_effect=abort_wait):
        try:
            instance_assembler._instances_assembler_loop()
        except RuntimeError as err:
            # Only the sentinel error raised above is tolerated.
            if "Stop iteration" not in str(err):
                raise
def test_multiple_instances_registration(instance_assembler, test_config):
    """Several jobs can be registered side by side and each receives a distinct instance id."""
    num_instances = 5
    for index in range(num_instances):
        assembled = register_instance_with_pods(instance_assembler, f"perf_test_{index}", test_config)
        assert assembled

    assert len(instance_assembler.instances) == num_instances

    # Collapsing the ids into a set must not lose any of them.
    unique_ids = {metadata.instance.id for metadata in instance_assembler.instances.values()}
    assert len(unique_ids) == num_instances
def test_ins_id_cnt_increment(instance_assembler, test_config):
    """ins_id_cnt grows by exactly one for every newly registered job."""
    starting_cnt = instance_assembler.ins_id_cnt

    for offset, job_name in enumerate(("job1", "job2"), start=1):
        register_instance_with_pods(instance_assembler, job_name, test_config)
        assert instance_assembler.ins_id_cnt == starting_cnt + offset
def test_update_config(instance_assembler):
    """update_config adopts the new etcd settings and builds a new ETCD client from them."""
    new_config = ControllerConfig()
    etcd_cfg = new_config.etcd_config
    etcd_cfg.etcd_host = "new-etcd-host"
    etcd_cfg.etcd_port = 2380
    etcd_cfg.etcd_timeout = 30.0
    etcd_cfg.enable_etcd_persistence = True

    with patch('motor.controller.core.instance_assembler.EtcdClient') as etcd_cls:
        etcd_cls.return_value = MagicMock()
        etcd_cls.reset_mock()

        instance_assembler.update_config(new_config)

        # The assembler now holds the new etcd config object itself.
        assert instance_assembler.etcd_config is new_config.etcd_config
        assert instance_assembler.etcd_config.etcd_host == "new-etcd-host"
        assert instance_assembler.etcd_config.etcd_port == 2380
        assert instance_assembler.etcd_config.etcd_timeout == 30.0
        # Exactly one client is constructed, from the new etcd and TLS settings.
        etcd_cls.assert_called_once_with(
            etcd_config=new_config.etcd_config,
            tls_config=new_config.etcd_tls_config,
        )
def test_persist_and_restore_data_success(instance_assembler, test_config):
    """Persist an assembled instance, then restore it via a second assembler and compare both sides."""
    job = "test_persist"
    metadata = create_assembled_instance(instance_assembler, job, test_config)
    instance_assembler.etcd_config.enable_etcd_persistence = True

    with (
        patch.object(instance_assembler.etcd_client, 'persist_data', return_value=True) as mock_persist,
        patch.object(instance_assembler.etcd_client, 'restore_data') as mock_restore,
    ):
        # --- persist phase -------------------------------------------------
        assert instance_assembler.persist_data() is True
        mock_persist.assert_called_once()
        persist_args, _ = mock_persist.call_args
        assert "/controller/instance_assembler" in persist_args[0]

        # --- build the payload the mocked ETCD restore hands back -----------
        dumped_metadata = metadata.model_dump(mode='json')
        snapshot = {"ins_id_cnt": instance_assembler.ins_id_cnt, "instances": {job: dumped_metadata}}
        stored_state = PersistentState(data=snapshot, version=1, timestamp=time.time(), checksum="")
        # The checksum is computed only after the state object exists, then written back.
        stored_state.checksum = stored_state.calculate_checksum()
        mock_restore.return_value = {"state": stored_state}

        # --- restore phase ---------------------------------------------------
        # threading.Thread and EtcdClient are replaced while the second assembler is built.
        with patch('threading.Thread'), patch('motor.controller.core.instance_assembler.EtcdClient'):
            restore_config = ControllerConfig()
            restore_config.etcd_config.enable_etcd_persistence = True
            restore_config.instance_config.instance_assemble_timeout = 1.0
            restore_config.instance_config.instance_assembler_check_interval = 0.1
            restore_config.instance_config.instance_assembler_cmd_send_interval = 0.1
            restore_config.instance_config.send_cmd_retry_times = 3

            new_assembler = InstanceAssembler(restore_config)
            assert new_assembler.restore_data() is True

            assert new_assembler.ins_id_cnt == instance_assembler.ins_id_cnt
            assert job in new_assembler.instances
            restored = new_assembler.instances[job]
            assert restored.instance.job_name == metadata.instance.job_name
            assert restored.register_status == metadata.register_status
def test_persist_data_with_checksum_validation(instance_assembler, test_config):
    """The payload handed to ETCD must carry a non-empty checksum that validates as a PersistentState."""
    job = "test_checksum"
    create_assembled_instance(instance_assembler, job, test_config)
    instance_assembler.etcd_config.enable_etcd_persistence = True

    with patch.object(instance_assembler.etcd_client, 'persist_data', return_value=True) as mock_persist:
        assert instance_assembler.persist_data() is True

        # Second positional argument of persist_data is the mapping that gets stored.
        persisted = mock_persist.call_args[0][1]
        assert "state" in persisted
        state_payload = persisted["state"]

        assert "checksum" in state_payload
        assert len(state_payload["checksum"]) > 0

        assert "data" in state_payload
        inner = state_payload["data"]
        assert "ins_id_cnt" in inner
        assert "instances" in inner
        assert job in inner["instances"]

        # Rebuilding the model from the raw payload lets it verify its own checksum.
        rebuilt = PersistentState(**state_payload)
        assert rebuilt.is_valid()
def test_restore_data_with_invalid_checksum(instance_assembler, test_config):
    """restore_data must report failure and leave ins_id_cnt at 1 when the stored checksum is wrong."""
    corrupted_state = PersistentState(
        data={"ins_id_cnt": 5, "instances": {}},
        version=1,
        timestamp=time.time(),
        checksum="invalid_checksum",
    )
    etcd_payload = {"state": corrupted_state}

    with patch.object(instance_assembler.etcd_client, 'restore_data', return_value=etcd_payload):
        restored = instance_assembler.restore_data()
        assert restored is False
        # The counter stored in the corrupted payload (5) must not have been applied.
        assert instance_assembler.ins_id_cnt == 1
def test_persistence_disabled_in_config(instance_assembler, test_config):
    """With the persistence flag off, register() must not add ETCD writes beyond the explicit persist call."""
    instance_assembler.etcd_config.enable_etcd_persistence = False
    create_assembled_instance(instance_assembler, "test_disabled", test_config)

    with patch.object(instance_assembler.etcd_client, 'persist_data', return_value=True) as mock_persist:
        # Explicit persist call made by the test itself.
        assert instance_assembler.persist_data() is True

        register_msg = create_register_msg("test_register_disabled", test_config['pod_ip1'], test_config)
        instance_assembler.register(register_msg)

        # Still exactly one write after register(): the explicit call above.
        assert mock_persist.call_count == 1
def test_persist_empty_state(instance_assembler):
    """Persisting with no registered instances must still produce a complete state record."""
    instance_assembler.etcd_config.enable_etcd_persistence = True

    with patch.object(instance_assembler.etcd_client, 'persist_data', return_value=True) as mock_persist:
        assert instance_assembler.persist_data() is True

        # Second positional argument of persist_data is the mapping that gets stored.
        persisted = mock_persist.call_args[0][1]
        assert "state" in persisted
        state_record = persisted["state"]

        payload = state_record["data"]
        assert payload["ins_id_cnt"] == instance_assembler.ins_id_cnt
        assert len(payload["instances"]) == 0

        # Envelope fields must be populated even when there is nothing to store.
        assert state_record["version"] >= 1
        assert state_record["timestamp"] > 0
        assert len(state_record["checksum"]) > 0
def test_restore_no_data_available(instance_assembler):
    """restore_data must succeed and leave the assembler untouched when ETCD returns nothing."""
    with patch.object(instance_assembler.etcd_client, 'restore_data', return_value=None):
        outcome = instance_assembler.restore_data()
        assert outcome is True
        # No instances and the initial counter value: nothing was restored.
        assert len(instance_assembler.instances) == 0
        assert instance_assembler.ins_id_cnt == 1
def test_filter_abnormal_endpoints_all_normal(instance_assembler, test_config):
    """_filter_abnormal_endpoints must query each of the two node managers once when all report normal."""
    instance = Instance(
        job_name="test_filter_normal",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    for pod_ip in ("127.0.0.1", "127.0.0.2"):
        instance.add_node_mgr(pod_ip, "8088")

    query_target = 'motor.controller.core.instance_assembler.NodeManagerApiClient.query_status'
    with patch(query_target, return_value={"status": True}) as mock_query_status:
        instance_assembler._filter_abnormal_endpoints(instance)
        # One status query per registered node manager.
        assert mock_query_status.call_count == 2
def test_filter_abnormal_endpoints_with_abnormal(instance_assembler, test_config):
    """A reachable node manager answering {"status": False} must not cause any endpoint to be dropped."""
    instance = Instance(
        job_name="test_filter_abnormal",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    pod_ips = ("127.0.0.1", "127.0.0.2")
    for pod_ip in pod_ips:
        instance.add_node_mgr(pod_ip, "8088")
    # Endpoint ids and port suffixes follow the 1-based pod position.
    for position, pod_ip in enumerate(pod_ips, start=1):
        pod_endpoint = Endpoint(id=position, ip=pod_ip, business_port=f"100{position}", mgmt_port=f"900{position}")
        instance.add_endpoints(pod_ip, {position: pod_endpoint})

    query_target = 'motor.controller.core.instance_assembler.NodeManagerApiClient.query_status'
    # First manager reports normal, second reports abnormal; both calls return without raising.
    with patch(query_target, side_effect=[{"status": True}, {"status": False}]):
        instance_assembler._filter_abnormal_endpoints(instance)

        assert instance.get_endpoints_num() == 2
        for pod_ip in pod_ips:
            assert pod_ip in instance.endpoints
        assert len(instance.node_managers) == 2
def test_filter_abnormal_endpoints_invalid_response(instance_assembler, test_config):
    """A node manager that answers with an unexpected payload must keep its endpoint and its registration."""
    pod_ip = "127.0.0.1"
    instance = Instance(
        job_name="test_filter_invalid",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    instance.add_node_mgr(pod_ip, "8088")
    instance.add_endpoints(pod_ip, {1: Endpoint(id=1, ip=pod_ip, business_port="1001", mgmt_port="9001")})

    query_target = 'motor.controller.core.instance_assembler.NodeManagerApiClient.query_status'
    # The response has no "status" key at all.
    with patch(query_target, return_value={"invalid": "response"}):
        instance_assembler._filter_abnormal_endpoints(instance)

        assert instance.get_endpoints_num() == 1
        assert len(instance.node_managers) == 1
def test_filter_abnormal_endpoints_connection_error(instance_assembler, test_config):
    """When querying the node manager raises, its endpoint and its registration must both be removed."""
    pod_ip = "127.0.0.1"
    instance = Instance(
        job_name="test_filter_error",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    instance.add_node_mgr(pod_ip, "8088")
    instance.add_endpoints(pod_ip, {1: Endpoint(id=1, ip=pod_ip, business_port="1001", mgmt_port="9001")})

    query_target = 'motor.controller.core.instance_assembler.NodeManagerApiClient.query_status'
    with patch(query_target, side_effect=Exception("Connection failed")):
        instance_assembler._filter_abnormal_endpoints(instance)

        # The only pod was unreachable, so nothing is left.
        assert instance.get_endpoints_num() == 0
        assert len(instance.node_managers) == 0
def test_filter_abnormal_endpoints_mixed_scenarios(instance_assembler, test_config):
    """With one reachable and one unreachable node manager, only the unreachable pod must be dropped."""
    reachable_ip, unreachable_ip = "127.0.0.1", "127.0.0.2"
    instance = Instance(
        job_name="test_filter_mixed",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    instance.add_node_mgr(reachable_ip, "8088")
    instance.add_node_mgr(unreachable_ip, "8088")
    instance.add_endpoints(
        reachable_ip,
        {1: Endpoint(id=1, ip=reachable_ip, business_port="1001", mgmt_port="9001")},
    )
    instance.add_endpoints(
        unreachable_ip,
        {2: Endpoint(id=2, ip=unreachable_ip, business_port="1002", mgmt_port="9002")},
    )

    query_target = 'motor.controller.core.instance_assembler.NodeManagerApiClient.query_status'
    # First query answers normally, the second one raises.
    responses = [{"status": True}, Exception("Connection failed")]
    with patch(query_target, side_effect=responses):
        instance_assembler._filter_abnormal_endpoints(instance)

        assert instance.get_endpoints_num() == 1
        assert reachable_ip in instance.endpoints
        assert unreachable_ip not in instance.endpoints
        assert len(instance.node_managers) == 1
def test_filter_abnormal_endpoints_no_node_managers(instance_assembler, test_config, caplog):
    """An instance without node managers must produce the 'cannot filter endpoints' warning."""
    bare_instance = Instance(
        job_name="test_filter_no_managers",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    expected_warning = (
        "No node managers found for instance test_filter_no_managers(id:1), cannot filter endpoints"
    )

    # Capture records at WARNING level while the filter runs.
    with caplog.at_level('WARNING'):
        instance_assembler._filter_abnormal_endpoints(bare_instance)
        assert expected_warning in caplog.text
def test_assemble_instance_with_abnormal_endpoints(instance_assembler, test_config):
    """_assemble_instance must not reach ASSEMBLED when filtering leaves only two of four pods."""
    instance = Instance(
        job_name="test_assemble_abnormal",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    # Four pods, 127.0.0.1 .. 127.0.0.4, each with one endpoint and one node manager.
    for pod_no in range(1, 5):
        pod_ip = f"127.0.0.{pod_no}"
        pod_endpoint = Endpoint(id=pod_no, ip=pod_ip, business_port=f"100{pod_no}", mgmt_port=f"900{pod_no}")
        instance.add_endpoints(pod_ip, {pod_no: pod_endpoint})
        instance.add_node_mgr(pod_ip, "8088")

    metadata = AssembleInstanceMetadata(instance=instance, register_timestamp=time.time())

    def mock_filter(instance_to_filter):
        """Stand-in for the health filter: removes the endpoints of the first two pods if present."""
        for unhealthy_ip in ("127.0.0.1", "127.0.0.2"):
            if unhealthy_ip in instance_to_filter.endpoints:
                instance_to_filter.del_endpoints(unhealthy_ip)

    with patch.object(instance_assembler, '_filter_abnormal_endpoints', side_effect=mock_filter):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status != RegisterStatus.ASSEMBLED
def test_assemble_instance_with_healthy_endpoints(instance_assembler, test_config):
    """With four pods present and filtering stubbed out, the instance must be assembled and handed over."""
    instance = Instance(
        job_name="test_assemble_healthy",
        model_name="test_model",
        id=1,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    # Four pods, 127.0.0.1 .. 127.0.0.4, each with one endpoint and one node manager.
    for pod_no in range(1, 5):
        pod_ip = f"127.0.0.{pod_no}"
        pod_endpoint = Endpoint(id=pod_no, ip=pod_ip, business_port=f"100{pod_no}", mgmt_port=f"900{pod_no}")
        instance.add_endpoints(pod_ip, {pod_no: pod_endpoint})
        instance.add_node_mgr(pod_ip, "8088")

    metadata = AssembleInstanceMetadata(instance=instance, register_timestamp=time.time())

    # The health filter becomes a no-op mock; InstanceManager is replaced so the hand-over can be observed.
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        with patch('motor.controller.core.instance_assembler.InstanceManager') as manager_cls:
            manager_stub = MagicMock()
            manager_cls.return_value = manager_stub

            instance_assembler._assemble_instance(metadata)

            assert metadata.register_status == RegisterStatus.ASSEMBLED
            manager_stub.add_instance.assert_called_once_with(instance)
def test_is_endpoints_enough_multi_endpoint_disabled():
    """Check Instance.is_endpoints_enough with enable_multi_endpoints=False, plus one multi-endpoint contrast case."""
    # Each case: (job_name, instance id, world_size, node managers as (ip, port), expected result).
    # Every node manager is registered with device_num=8.
    single_endpoint_cases = [
        ("test_not_enough_nodes", 1, 16, [("127.0.0.1", "8080")], False),
        ("test_enough_nodes", 2, 16, [("127.0.0.1", "8080"), ("127.0.0.2", "8081")], True),
        (
            "test_ceiling_nodes",
            3,
            20,
            [("127.0.0.1", "8080"), ("127.0.0.2", "8081"), ("127.0.0.3", "8082")],
            True,
        ),
    ]
    for job_name, ins_id, world_size, node_mgrs, expected in single_endpoint_cases:
        candidate = Instance(
            job_name=job_name,
            model_name="test_model",
            id=ins_id,
            role="both",
            parallel_config=ParallelConfig(world_size=world_size),
            enable_multi_endpoints=False,
        )
        for mgr_ip, mgr_port in node_mgrs:
            candidate.add_node_mgr(mgr_ip, mgr_port, device_num=8)
        assert candidate.is_endpoints_enough() is expected

    # Contrast case: multi-endpoint mode with dp_size=4, fed two endpoints at a time.
    multi = Instance(
        job_name="test_multi_endpoint",
        model_name="test_model",
        id=4,
        role="both",
        parallel_config=ParallelConfig(dp_size=4, tp_size=1, pp_size=1),
        enable_multi_endpoints=True,
    )
    first_pod = {
        ep_id: Endpoint(id=ep_id, ip="127.0.0.1", business_port=f"800{ep_id}", mgmt_port=f"900{ep_id}")
        for ep_id in (0, 1)
    }
    multi.add_endpoints("127.0.0.1", first_pod)
    # Two endpoints registered so far.
    assert multi.is_endpoints_enough() is False

    second_pod = {
        ep_id: Endpoint(id=ep_id, ip="127.0.0.2", business_port=f"800{ep_id}", mgmt_port=f"900{ep_id}")
        for ep_id in (2, 3)
    }
    multi.add_endpoints("127.0.0.2", second_pod)
    # Four endpoints registered now.
    assert multi.is_endpoints_enough() is True
def test_get_all_endpoints_multi_endpoint_disabled():
    """Check what Instance.get_all_endpoints returns with enable_multi_endpoints off, on, and off across two pods."""

    def make_endpoints(pod_ip, endpoint_ids):
        """Build an id-keyed endpoint dict on pod_ip with ports 800<id> / 900<id>."""
        return {
            ep_id: Endpoint(id=ep_id, ip=pod_ip, business_port=f"800{ep_id}", mgmt_port=f"900{ep_id}")
            for ep_id in endpoint_ids
        }

    # The same three-endpoint dict is fed to the first two instances.
    shared_endpoints = make_endpoints("127.0.0.1", (0, 1, 2))

    # Case 1: multi-endpoint disabled, single pod -> only endpoint 0 is returned.
    single = Instance(
        job_name="test_single_endpoint",
        model_name="test_model",
        id=1,
        role="both",
        parallel_config=ParallelConfig(dp_size=2, tp_size=1, pp_size=1),
        enable_multi_endpoints=False,
    )
    single.add_endpoints("127.0.0.1", shared_endpoints)
    single_result = single.get_all_endpoints()
    assert len(single_result) == 1
    assert single_result[0].id == 0

    # Case 2: multi-endpoint enabled -> all three endpoints are returned.
    multi = Instance(
        job_name="test_all_endpoints",
        model_name="test_model",
        id=2,
        role="both",
        parallel_config=ParallelConfig(dp_size=3, tp_size=1, pp_size=1),
        enable_multi_endpoints=True,
    )
    multi.add_endpoints("127.0.0.1", shared_endpoints)
    multi_result = multi.get_all_endpoints()
    assert len(multi_result) == 3
    assert {ep.id for ep in multi_result} == {0, 1, 2}

    # Case 3: multi-endpoint disabled, two pods -> still only endpoint 0 is returned.
    two_pods = Instance(
        job_name="test_multi_pods",
        model_name="test_model",
        id=3,
        role="both",
        parallel_config=ParallelConfig(dp_size=2, tp_size=1, pp_size=1),
        enable_multi_endpoints=False,
    )
    pod1_endpoints = make_endpoints("127.0.0.1", (0, 1))
    pod2_endpoints = make_endpoints("127.0.0.2", (2, 3))
    two_pods.add_endpoints("127.0.0.1", pod1_endpoints)
    two_pods.add_endpoints("127.0.0.2", pod2_endpoints)
    two_pods_result = two_pods.get_all_endpoints()
    assert len(two_pods_result) == 1
    assert two_pods_result[0].id == 0
def test_assemble_instance_multi_endpoint_disabled(instance_assembler):
    """In single-endpoint mode, two 8-device node managers for world_size=16 must lead to ASSEMBLED."""
    instance = Instance(
        job_name="test_assemble_multi_disabled",
        model_name="test_model",
        id=1,
        role="both",
        parallel_config=ParallelConfig(world_size=16),
        enable_multi_endpoints=False,
    )
    instance.add_node_mgr("127.0.0.1", "8080", device_num=8)
    instance.add_node_mgr("127.0.0.2", "8081", device_num=8)
    # Both pods expose an endpoint with id 0 on identical ports.
    for pod_ip in ("127.0.0.1", "127.0.0.2"):
        instance.add_endpoints(
            pod_ip,
            {0: Endpoint(id=0, ip=pod_ip, business_port="8000", mgmt_port="9000")},
        )

    metadata = AssembleInstanceMetadata(instance=instance, register_timestamp=time.time())

    # The health filter becomes a no-op mock; InstanceManager is replaced so the hand-over can be observed.
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        with patch('motor.controller.core.instance_assembler.InstanceManager') as manager_cls:
            manager_stub = MagicMock()
            manager_cls.return_value = manager_stub

            instance_assembler._assemble_instance(metadata)

            assert metadata.register_status == RegisterStatus.ASSEMBLED
            manager_stub.add_instance.assert_called_once_with(instance)
def test_assemble_instance_multi_endpoint_disabled_not_enough_nodes(instance_assembler):
    """In single-endpoint mode, one 8-device node manager for world_size=16 must not lead to ASSEMBLED."""
    only_pod_ip = "127.0.0.1"
    instance = Instance(
        job_name="test_assemble_not_enough",
        model_name="test_model",
        id=1,
        role="both",
        parallel_config=ParallelConfig(world_size=16),
        enable_multi_endpoints=False,
    )
    instance.add_node_mgr(only_pod_ip, "8080", device_num=8)
    instance.add_endpoints(
        only_pod_ip,
        {0: Endpoint(id=0, ip=only_pod_ip, business_port="8000", mgmt_port="9000")},
    )
    metadata = AssembleInstanceMetadata(instance=instance, register_timestamp=time.time())

    # The health filter is stubbed so only the node-count check decides the outcome.
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status != RegisterStatus.ASSEMBLED
def _make_mock_readonly_instance(
    job_name: str,
    role: str,
    ips: list[str],
    *,
    endpoint_id: int = 0,
    endpoint_ids: list[int] | None = None,
):
    """Create a ReadOnlyInstance wrapping a real Instance (required by upstream merge).

    One pod named ``pod-<job_name>-<idx>`` is added per entry in ``ips``; each pod
    holds a single endpoint stored under key 0.

    Args:
        job_name: Job name of the wrapped instance; also seeds its numeric id.
        role: Role assigned to the wrapped instance.
        ips: Endpoint IP per pod, in pod order.
        endpoint_id: Endpoint id used for every pod when ``endpoint_ids`` is None.
        endpoint_ids: Optional per-pod endpoint ids, aligned with ``ips``.

    Returns:
        A ReadOnlyInstance around the populated Instance.

    Raises:
        ValueError: If ``endpoint_ids`` is given but has fewer entries than ``ips``.
    """
    if endpoint_ids is not None and len(endpoint_ids) < len(ips):
        raise ValueError(
            f"endpoint_ids has {len(endpoint_ids)} entries but {len(ips)} ips were given"
        )

    # The builtin hash() of a str is salted per interpreter process, which made the
    # instance id differ between test runs. A digest keeps it reproducible.
    stable_id = int(hashlib.sha256(job_name.encode("utf-8")).hexdigest(), 16) % 1_000_000

    inst = Instance(
        job_name=job_name,
        model_name="test",
        id=stable_id,
        role=role,
        parallel_config=ParallelConfig(),
    )
    for idx, ip in enumerate(ips):
        ep_id = endpoint_id if endpoint_ids is None else endpoint_ids[idx]
        inst.add_endpoints(
            f"pod-{job_name}-{idx}",
            {0: Endpoint(id=ep_id, ip=ip, business_port="8000", mgmt_port="9000")},
        )
    return ReadOnlyInstance(inst)
def _make_mock_peer_instance_cross_node(job_name: str, role: str, ip_by_rank: dict[int, str]):
    """Peer instance with one pod per DP rank (cross-node DP topology).

    Args:
        job_name: Job name of the wrapped instance; also seeds its numeric id.
        role: Role assigned to the wrapped instance.
        ip_by_rank: Maps a DP rank to the IP of the pod serving it. The rank is used
            as the endpoint id and as the suffix of the pod name.

    Returns:
        A ReadOnlyInstance around the populated Instance.
    """
    # The builtin hash() of a str is salted per interpreter process, which made the
    # instance id differ between test runs. A digest keeps it reproducible.
    stable_id = int(hashlib.sha256(job_name.encode("utf-8")).hexdigest(), 16) % 1_000_000

    inst = Instance(
        job_name=job_name,
        model_name="test",
        id=stable_id,
        role=role,
        parallel_config=ParallelConfig(),
    )
    for dp_rank, ip in ip_by_rank.items():
        # Each pod holds exactly one endpoint, stored under key 0.
        inst.add_endpoints(
            f"pod-{job_name}-{dp_rank}",
            {0: Endpoint(id=dp_rank, ip=ip, business_port="8000", mgmt_port="9000")},
        )
    return ReadOnlyInstance(inst)
def test_collect_d2d_peer_ips_queries_active_only(instance_assembler, test_config):
    """_collect_d2d_peer_ips must ask InstanceManager for ACTIVE instances only."""
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)
    local_endpoints = [Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")]

    with patch.object(InstanceManager(), 'get_instances', return_value=[]) as mock_get:
        instance_assembler._collect_d2d_peer_ips(metadata, local_endpoints)
        # Exactly one lookup, filtered by a status set that contains only ACTIVE.
        mock_get.assert_called_once_with({InsStatus.ACTIVE})
def test_collect_d2d_peer_ips_matches_dp_rank(instance_assembler, test_config):
    """Only peer endpoints whose id equals the local endpoint id are collected; no match yields None."""
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)
    peer = _make_mock_peer_instance_cross_node(
        "peer_job",
        test_config['role'],
        {0: "10.0.0.1", 1: "10.0.0.2", 2: "10.0.0.3", 3: "10.0.0.4"},
    )

    def local_endpoint(ep_id):
        """Local endpoint on 127.0.0.1 carrying the given id."""
        return Endpoint(id=ep_id, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    matching_cases = [
        (local_endpoint(0), ["0:10.0.0.1"]),
        (local_endpoint(2), ["2:10.0.0.3"]),
    ]
    # Id 5 does not exist on the peer, which only has ids 0..3.
    unmatched = local_endpoint(5)

    with patch.object(InstanceManager(), 'get_instances', return_value=[peer]):
        for endpoint, expected in matching_cases:
            assert instance_assembler._collect_d2d_peer_ips(metadata, [endpoint]) == expected
        assert instance_assembler._collect_d2d_peer_ips(metadata, [unmatched]) is None
def test_collect_d2d_peer_ips_active_same_role(instance_assembler, test_config):
    """IPs of a same-role peer returned by InstanceManager must all be collected."""
    role = test_config['role']
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=role,
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)
    peer = _make_mock_readonly_instance("peer_job", role, ["10.0.0.1", "10.0.0.2"])
    local_ep = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=[peer]):
        collected = instance_assembler._collect_d2d_peer_ips(metadata, [local_ep])
        # Compared as a set: ordering of the result is not part of the check.
        assert set(collected) == {"0:10.0.0.1", "0:10.0.0.2"}
def test_collect_d2d_peer_ips_excludes_own_job_name(instance_assembler, test_config):
    """An instance sharing the current job_name must be skipped as a peer."""
    role = test_config['role']
    current = Instance(
        job_name="my_job",
        model_name="test",
        id=99,
        role=role,
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)

    # Same job_name as the instance under assembly -> must be ignored.
    same_job = _make_mock_readonly_instance("my_job", role, ["10.0.0.1"])
    other_job = _make_mock_readonly_instance("other_job", role, ["10.0.0.2"])
    local_ep = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=[same_job, other_job]):
        collected = instance_assembler._collect_d2d_peer_ips(metadata, [local_ep])
        assert collected == ["0:10.0.0.2"]
def test_collect_d2d_peer_ips_excludes_different_role(instance_assembler, test_config):
    """A peer whose role differs from the current instance's role must be skipped."""
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role="prefill",
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)

    prefill_peer = _make_mock_readonly_instance("peer_prefill", "prefill", ["10.0.0.1"])
    # Role "decode" does not match the current instance's "prefill".
    decode_peer = _make_mock_readonly_instance("peer_decode", "decode", ["10.0.0.2"])
    local_ep = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=[prefill_peer, decode_peer]):
        collected = instance_assembler._collect_d2d_peer_ips(metadata, [local_ep])
        assert collected == ["0:10.0.0.1"]
def test_collect_d2d_peer_ips_deduplicates(instance_assembler, test_config):
    """An IP shared by two peers must still be part of the collected set of peer entries."""
    role = test_config['role']
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=role,
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)

    # 10.0.0.2 is present in both peers.
    peers = [
        _make_mock_readonly_instance("peer1", role, ["10.0.0.1", "10.0.0.2"]),
        _make_mock_readonly_instance("peer2", role, ["10.0.0.2", "10.0.0.3"]),
    ]
    local_ep = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=peers):
        collected = instance_assembler._collect_d2d_peer_ips(metadata, [local_ep])
        assert set(collected) == {"0:10.0.0.1", "0:10.0.0.2", "0:10.0.0.3"}
def test_collect_d2d_peer_ips_no_peers(instance_assembler, test_config):
    """With no instances returned by InstanceManager the result must be None."""
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    metadata = AssembleInstanceMetadata(instance=current)
    local_ep = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=[]):
        # None, not an empty list, is the expected "nothing found" value.
        assert instance_assembler._collect_d2d_peer_ips(metadata, [local_ep]) is None
def test_send_start_command_with_d2d_enabled(instance_assembler, test_config):
    """With D2D enabled, each start command must carry the peer IP matching its own endpoint id."""
    job = "d2d_job"
    pod_ips = ("127.0.0.1", "127.0.0.2")
    instance = Instance(
        job_name=job,
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    for pod_ip in pod_ips:
        instance.add_node_mgr(pod_ip, "8088")
    # The second argument of _build_single_endpoint is the pod's position (0, then 1).
    for position, pod_ip in enumerate(pod_ips):
        reg_msg = create_register_msg(job, pod_ip, test_config)
        instance.add_endpoints(pod_ip, instance_assembler._build_single_endpoint(reg_msg, position))

    metadata = AssembleInstanceMetadata(instance=instance)
    peer = _make_mock_peer_instance_cross_node(
        "peer_job",
        test_config['role'],
        {0: "10.0.0.1", 1: "10.0.0.2"},
    )

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch.object(instance_assembler, '_is_d2d_enabled_for_role', return_value=True):
        with patch.object(InstanceManager(), 'get_instances', return_value=[peer]):
            with patch(send_target) as mock_send:
                mock_send.return_value = True

                assert instance_assembler._send_start_command(metadata) is True
                assert mock_send.call_count == 2

                # Second positional argument of each call is the start message.
                first_msg, second_msg = (recorded[0][1] for recorded in mock_send.call_args_list)
                assert first_msg.d2d_peer_ips == ["0:10.0.0.1"]
                assert second_msg.d2d_peer_ips == ["1:10.0.0.2"]
def test_send_start_command_with_d2d_disabled(instance_assembler, test_config):
    """With D2D disabled, d2d_peer_ips must stay None even though a same-role peer exists."""
    job, pod_ip = "d2d_job", "127.0.0.1"
    instance = Instance(
        job_name=job,
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    instance.add_node_mgr(pod_ip, "8088")
    reg_msg = create_register_msg(job, pod_ip, test_config)
    instance.add_endpoints(pod_ip, instance_assembler._build_single_endpoint(reg_msg, 0))

    metadata = AssembleInstanceMetadata(instance=instance)
    peer = _make_mock_readonly_instance("peer_job", test_config['role'], ["10.0.0.1"])

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch.object(instance_assembler, '_is_d2d_enabled_for_role', return_value=False):
        with patch.object(InstanceManager(), 'get_instances', return_value=[peer]):
            with patch(send_target) as mock_send:
                mock_send.return_value = True
                instance_assembler._send_start_command(metadata)

                # Second positional argument of the last call is the start message.
                sent_msg = mock_send.call_args[0][1]
                assert sent_msg.d2d_peer_ips is None
def test_send_start_command_with_d2d_enabled_no_peers(instance_assembler, test_config):
    """With D2D enabled but no peer instances available, d2d_peer_ips must stay None."""
    job, pod_ip = "d2d_job", "127.0.0.1"
    instance = Instance(
        job_name=job,
        model_name="test",
        id=99,
        role=test_config['role'],
        parallel_config=test_config['parallel_config'],
    )
    instance.add_node_mgr(pod_ip, "8088")
    reg_msg = create_register_msg(job, pod_ip, test_config)
    instance.add_endpoints(pod_ip, instance_assembler._build_single_endpoint(reg_msg, 0))

    metadata = AssembleInstanceMetadata(instance=instance)

    send_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch.object(instance_assembler, '_is_d2d_enabled_for_role', return_value=True):
        # InstanceManager reports no instances at all.
        with patch.object(InstanceManager(), 'get_instances', return_value=[]):
            with patch(send_target) as mock_send:
                mock_send.return_value = True
                instance_assembler._send_start_command(metadata)

                # Second positional argument of the last call is the start message.
                sent_msg = mock_send.call_args[0][1]
                assert sent_msg.d2d_peer_ips is None
def test_collect_d2d_peer_ips_includes_headless(instance_assembler, test_config):
    """A peer endpoint flagged headless=True must be collected like a regular one."""
    role = test_config['role']
    current = Instance(
        job_name="current_job",
        model_name="test",
        id=99,
        role=role,
        parallel_config=test_config['parallel_config'],
        enable_multi_endpoints=True,
    )
    metadata = AssembleInstanceMetadata(instance=current)

    peer = Instance(
        job_name="peer_job",
        model_name="test",
        id=88,
        role=role,
        parallel_config=ParallelConfig(),
        enable_multi_endpoints=True,
    )
    regular_ep = Endpoint(id=0, ip="10.0.0.1", business_port="8000", mgmt_port="9000")
    headless_ep = Endpoint(id=1, ip="10.0.0.2", business_port="8000", mgmt_port="9000", headless=True)
    # Both pods store their single endpoint under key 0; only the endpoint ids differ.
    peer.add_endpoints("10.0.0.1", {0: regular_ep})
    peer.add_endpoints("10.0.0.2", {0: headless_ep})
    readonly_peer = ReadOnlyInstance(peer)

    local_ep0 = Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")
    local_ep1 = Endpoint(id=1, ip="127.0.0.1", business_port="8000", mgmt_port="9000")

    with patch.object(InstanceManager(), 'get_instances', return_value=[readonly_peer]):
        assert instance_assembler._collect_d2d_peer_ips(metadata, [local_ep0]) == ["0:10.0.0.1"]
        # Local id 1 pairs with the headless peer endpoint.
        assert instance_assembler._collect_d2d_peer_ips(metadata, [local_ep1]) == ["1:10.0.0.2"]
def test_cross_node_pcp_assembly_waits_for_all_nodes(instance_assembler):
    """With nnodes=2, assembly must hold back after one node and complete once the second has joined."""
    instance = Instance(
        job_name="test_pcp_wait",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=1, tp_size=4),
        enable_multi_endpoints=True,
    )
    metadata = None  # assigned below, once the first node is registered

    # --- first node only ---------------------------------------------------
    instance.add_node_mgr("127.0.0.1", "8080", device_num=8)
    instance.add_endpoints(
        "127.0.0.1",
        {0: Endpoint(id=0, ip="127.0.0.1", business_port="8000", mgmt_port="9000")},
    )
    metadata = AssembleInstanceMetadata(
        instance=instance,
        register_timestamp=time.time(),
        nnodes=2,
    )
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status != RegisterStatus.ASSEMBLED

    # --- second node joins -------------------------------------------------
    instance.add_node_mgr("127.0.0.2", "8080", device_num=8)
    instance.add_endpoints(
        "127.0.0.2",
        {0: Endpoint(id=1, ip="127.0.0.2", business_port="8000", mgmt_port="9000")},
    )
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status == RegisterStatus.ASSEMBLED
def test_nnodes_default_backward_compatible(instance_assembler):
    """With nnodes=1, a dp_size=2 instance holding a single endpoint must not be assembled."""
    pod_ip = "127.0.0.1"
    instance = Instance(
        job_name="test_nnodes_default",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=2, tp_size=4),
        enable_multi_endpoints=True,
    )
    instance.add_node_mgr(pod_ip, "8080", device_num=8)
    instance.add_endpoints(
        pod_ip,
        {0: Endpoint(id=0, ip=pod_ip, business_port="8000", mgmt_port="9000")},
    )
    metadata = AssembleInstanceMetadata(
        instance=instance,
        register_timestamp=time.time(),
        nnodes=1,
    )

    # Only one endpoint is registered for this instance.
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status != RegisterStatus.ASSEMBLED
def test_cross_node_pcp_assembly_extra_nodes(instance_assembler):
    """Three registered node managers with nnodes=2 must still lead to ASSEMBLED."""
    instance = Instance(
        job_name="test_pcp_extra",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=1, tp_size=4),
        enable_multi_endpoints=True,
    )
    # Endpoint ids are zero-based while the last IP octet is one-based.
    for octet in ("1", "2", "3"):
        pod_ip = f"127.0.0.{octet}"
        instance.add_node_mgr(pod_ip, "8080", device_num=8)
        pod_endpoint = Endpoint(id=int(octet) - 1, ip=pod_ip, business_port="8000", mgmt_port="9000")
        instance.add_endpoints(pod_ip, {0: pod_endpoint})

    metadata = AssembleInstanceMetadata(
        instance=instance,
        register_timestamp=time.time(),
        nnodes=2,
    )
    with patch.object(instance_assembler, '_filter_abnormal_endpoints'):
        instance_assembler._assemble_instance(metadata)
        assert metadata.register_status == RegisterStatus.ASSEMBLED
def test_cross_node_pcp_with_dp_waits_for_all_groups(instance_assembler):
    """Check that dp_size=4 combined with nnodes=2 waits for eight nodes.

    With seven nodes registered the metadata must not be ASSEMBLED; after the
    eighth node is added, a second assembly attempt must reach ASSEMBLED.
    """
    combo_instance = Instance(
        job_name="test_pcp_dp_combo",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=4, tp_size=16, pcp_size=2),
        enable_multi_endpoints=True,
    )

    def add_node(index):
        # Register node manager and its single endpoint for the 0-based node index.
        node_ip = f"10.0.0.{index + 1}"
        combo_instance.add_node_mgr(node_ip, "8080", device_num=16)
        combo_instance.add_endpoints(
            node_ip,
            {0: Endpoint(id=index, ip=node_ip, business_port="8000", mgmt_port="9000")},
        )

    for index in range(7):
        add_node(index)

    assemble_meta = AssembleInstanceMetadata(instance=combo_instance, nnodes=2)

    def assemble_once():
        # Run one assembly pass with _filter_abnormal_endpoints mocked out.
        with patch.object(instance_assembler, "_filter_abnormal_endpoints"):
            instance_assembler._assemble_instance(assemble_meta)

    # Seven of eight nodes present: assembly must not complete.
    assemble_once()
    assert assemble_meta.register_status != RegisterStatus.ASSEMBLED

    # The eighth node arrives: assembly completes.
    add_node(7)
    assemble_once()
    assert assemble_meta.register_status == RegisterStatus.ASSEMBLED
def test_send_start_command_assigns_node_rank(instance_assembler):
    """Check that _send_start_command hands out node_rank following registration order.

    Nodes are registered in the order .2, .1, .3 and the captured start
    commands must carry ranks 0, 1, 2 in that same order.
    """
    registration_order = ["10.0.0.2", "10.0.0.1", "10.0.0.3"]
    ranked_instance = Instance(
        job_name="test_node_rank",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=1, tp_size=4),
        enable_multi_endpoints=True,
    )
    # All node managers are added first, then all endpoints.
    for node_ip in registration_order:
        ranked_instance.add_node_mgr(node_ip, "8080", device_num=8)
    for node_ip in registration_order:
        ranked_instance.add_endpoints(
            node_ip,
            {0: Endpoint(id=0, ip=node_ip, business_port="8000", mgmt_port="9000")},
        )

    assemble_meta = AssembleInstanceMetadata(instance=ranked_instance, nnodes=3)
    recorded = []

    def capture_call(node_mgr, start_cmd_msg):
        # Record (pod ip, node rank) for each start command and report success.
        recorded.append((node_mgr.pod_ip, start_cmd_msg.node_rank))
        return True

    patch_target = 'motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command'
    with patch(patch_target, side_effect=capture_call):
        instance_assembler._send_start_command(assemble_meta)

    assert len(recorded) == 3
    for position, node_ip in enumerate(registration_order):
        assert recorded[position] == (node_ip, position)
def test_send_start_command_node_rank_modulo_for_dp_pcp(instance_assembler):
    """Check node_rank = registration_index % nnodes for DP=2, nnodes=2 and four nodes.

    Expected ranks for 10.0.0.1 .. 10.0.0.4 are 0, 1, 0, 1.
    """
    modulo_instance = Instance(
        job_name="test_node_rank_mod",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=2, tp_size=4, pcp_size=2),
        enable_multi_endpoints=True,
    )
    for index in range(4):
        node_ip = f"10.0.0.{index + 1}"
        modulo_instance.add_node_mgr(node_ip, "8080", device_num=8)
        modulo_instance.add_endpoints(
            node_ip,
            {0: Endpoint(id=index, ip=node_ip, business_port="8000", mgmt_port="9000")},
        )

    assemble_meta = AssembleInstanceMetadata(instance=modulo_instance, nnodes=2)
    rank_by_ip = {}

    def capture_call(node_mgr, start_cmd_msg):
        # Remember the rank delivered to each pod and report success.
        rank_by_ip[node_mgr.pod_ip] = start_cmd_msg.node_rank
        return True

    patch_target = "motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command"
    with patch(patch_target, side_effect=capture_call):
        instance_assembler._send_start_command(assemble_meta)

    expected_ranks = {
        "10.0.0.1": 0,
        "10.0.0.2": 1,
        "10.0.0.3": 0,
        "10.0.0.4": 1,
    }
    for node_ip, expected_rank in expected_ranks.items():
        assert rank_by_ip[node_ip] == expected_rank
def test_register_msg_nnodes_stored_in_metadata(instance_assembler, test_config):
    """Check that the nnodes value of a RegisterMsg ends up on the stored AssembleInstanceMetadata."""
    job_name = "test_nnodes_stored"
    register_msg = create_register_msg(job_name, test_config['pod_ip1'], test_config, nnodes=3)

    status_code = instance_assembler.register(register_msg)
    assert status_code == 0

    metadata = instance_assembler.instances[job_name]
    assert metadata.nnodes == 3, f"Expected nnodes=3, got {metadata.nnodes}"
def test_register_msg_nnodes_default(instance_assembler, test_config):
    """Check that metadata.nnodes is 1 when the register message is built without nnodes."""
    job_name = "test_nnodes_default_reg"
    pod_ip = test_config['pod_ip1']
    register_msg = create_register_msg(job_name, pod_ip, test_config)

    status_code = instance_assembler.register(register_msg)
    assert status_code == 0

    metadata = instance_assembler.instances[job_name]
    assert metadata.nnodes == 1, f"Expected default nnodes=1, got {metadata.nnodes}"
def test_reregister_preserves_nnodes(instance_assembler, test_config):
    """Check that re-registering an unknown job keeps the nnodes carried by the ReregisterMsg.

    The job is registered once with nnodes=2, the assembler's instances
    mapping is emptied, and a ReregisterMsg with nnodes=2 is then submitted.
    """
    job_name = "test_reregister_nnodes"
    cfg = test_config.copy()
    pod_ip = cfg['pod_ip1']

    # Initial registration with nnodes=2.
    first_registration = create_register_msg(job_name, pod_ip, cfg, nnodes=2)
    status_code = instance_assembler.register(first_registration)
    assert status_code == 0
    assert instance_assembler.instances[job_name].nnodes == 2

    # Empty the instances mapping so the job is unknown when reregister is called.
    instance_assembler.instances.clear()

    reregister_msg = ReregisterMsg(
        job_name=job_name,
        model_name="test_model",
        instance_id=1,
        role=cfg['role'],
        pod_ip=pod_ip,
        nm_port="8088",
        parallel_config=cfg['parallel_config'],
        endpoints=[Endpoint(id=0, ip=pod_ip, business_port="8000", mgmt_port="9000")],
        enable_multi_endpoints=True,
        nnodes=2,
    )
    status_code = instance_assembler.reregister(reregister_msg)
    assert status_code == 0
    assert instance_assembler.instances[job_name].nnodes == 2
def test_cross_node_pcp_marks_slave_endpoints_headless(instance_assembler):
    """Check that with nnodes > 1 the slave node's endpoints (node_rank > 0) become headless.

    The first registered node (10.0.0.10) is expected to stay non-headless,
    the second (10.0.0.2) to be headless, and get_all_endpoints() to return
    only the master endpoint.
    """
    master_ip = "10.0.0.10"
    slave_ip = "10.0.0.2"
    headless_instance = Instance(
        job_name="test_headless_marking",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=1, tp_size=4),
        enable_multi_endpoints=True,
    )
    # Both node managers are added before any endpoints.
    for node_ip in (master_ip, slave_ip):
        headless_instance.add_node_mgr(node_ip, "8080", device_num=8)
    for node_ip in (master_ip, slave_ip):
        headless_instance.add_endpoints(
            node_ip,
            {0: Endpoint(id=0, ip=node_ip, business_port="8000", mgmt_port="9000")},
        )

    assemble_meta = AssembleInstanceMetadata(instance=headless_instance, nnodes=2)
    with patch.object(instance_assembler, "_filter_abnormal_endpoints"):
        instance_assembler._assemble_instance(assemble_meta)
    assert assemble_meta.register_status == RegisterStatus.ASSEMBLED

    master_endpoints = headless_instance.get_endpoints(master_ip)
    for ep in master_endpoints.values():
        assert ep.headless is False, f"Master endpoint {ep.ip} should not be headless"

    slave_endpoints = headless_instance.get_endpoints(slave_ip)
    for ep in slave_endpoints.values():
        assert ep.headless is True, f"Slave endpoint {ep.ip} should be headless"

    visible_endpoints = headless_instance.get_all_endpoints()
    assert len(visible_endpoints) == 1
    assert visible_endpoints[0].ip == master_ip
def test_cross_node_pcp_reregister_preserves_headless(instance_assembler):
    """Check that re-registration takes node_rank from the ReregisterMsg, not from arrival order.

    The slave (node_rank=1) re-registers before the master (node_rank=0);
    the slave's endpoints must still be headless and the master's must not.
    """
    job_name = "test_reregister_headless"
    slave_ip = "10.0.0.200"
    master_ip = "10.0.0.1"

    def build_reregister_msg(pod_ip, endpoint, node_rank):
        # Each message gets its own ParallelConfig object.
        return ReregisterMsg(
            job_name=job_name,
            model_name="test_model",
            instance_id=1,
            role="prefill",
            pod_ip=pod_ip,
            nm_port="8088",
            parallel_config=ParallelConfig(dp_size=1, tp_size=4),
            endpoints=[endpoint],
            nnodes=2,
            node_rank=node_rank,
        )

    slave_endpoint = Endpoint(id=0, ip=slave_ip, business_port="8000", mgmt_port="9000")
    master_endpoint = Endpoint(id=1, ip=master_ip, business_port="8000", mgmt_port="9000")

    # The slave is deliberately re-registered first.
    instance_assembler.reregister(build_reregister_msg(slave_ip, slave_endpoint, 1))
    instance_assembler.reregister(build_reregister_msg(master_ip, master_endpoint, 0))

    stored = instance_assembler.instances[job_name]
    assert stored.nnodes == 2
    assert stored.is_reregister is True

    for ep in stored.instance.get_endpoints(slave_ip).values():
        assert ep.headless is True

    for ep in stored.instance.get_endpoints(master_ip).values():
        assert ep.headless is False
def test_cross_node_pcp_no_headless_when_nnodes_is_one(instance_assembler):
    """Check that no endpoint is flagged headless when nnodes=1.

    Two nodes with one endpoint each and dp_size=2 are assembled; every
    endpoint returned by get_all_endpoints() must have headless False.
    """
    plain_instance = Instance(
        job_name="test_no_headless_nnodes1",
        model_name="test_model",
        id=1,
        role="prefill",
        parallel_config=ParallelConfig(dp_size=2, tp_size=4),
        enable_multi_endpoints=True,
    )
    # (ip, endpoint id) pairs; node managers are all added before any endpoints.
    nodes = [("10.0.0.1", 0), ("10.0.0.2", 1)]
    for node_ip, _ in nodes:
        plain_instance.add_node_mgr(node_ip, "8080", device_num=8)
    for node_ip, endpoint_id in nodes:
        plain_instance.add_endpoints(
            node_ip,
            {0: Endpoint(id=endpoint_id, ip=node_ip, business_port="8000", mgmt_port="9000")},
        )

    assemble_meta = AssembleInstanceMetadata(instance=plain_instance, nnodes=1)
    with patch.object(instance_assembler, "_filter_abnormal_endpoints"):
        instance_assembler._assemble_instance(assemble_meta)

    assert assemble_meta.register_status == RegisterStatus.ASSEMBLED
    for ep in plain_instance.get_all_endpoints():
        assert ep.headless is False
def test_register_records_snapshot_dp_master_ip_when_is_master(instance_assembler, test_config):
    """Check that registering with is_master=True stores that pod's IP as snapshot_dp_master_ip.

    A non-master node registers first, then the master; the metadata must
    report the master's IP.
    """
    job_name = "test_snapshot_master_register"
    master_ip = "10.0.0.1"

    slave_msg = create_register_msg(job_name, "10.0.0.2", test_config, is_master=False)
    master_msg = create_register_msg(job_name, master_ip, test_config, is_master=True)

    # Slave registers before the master.
    for msg in (slave_msg, master_msg):
        assert instance_assembler.register(msg) == 0

    stored = instance_assembler.instances[job_name]
    assert stored.snapshot_dp_master_ip == master_ip
def test_send_start_command_uses_snapshot_dp_master_ip(instance_assembler, test_config):
    """Check that start commands carry snapshot_dp_master_ip as master_dp_ip.

    snapshot_dp_master_ip is overridden on an assembled instance and every
    captured start command must report that value.
    """
    job_name = "test_snapshot_master_start"
    forced_master_ip = "10.0.0.99"

    assemble_meta = create_assembled_instance(instance_assembler, job_name, test_config)
    assemble_meta.snapshot_dp_master_ip = forced_master_ip

    patch_target = "motor.controller.api_client.node_manager_api_client.NodeManagerApiClient.send_start_command"
    with patch(patch_target) as mock_send:
        mock_send.return_value = True
        assert instance_assembler._send_start_command(assemble_meta) is True

    # The start command message is the second positional argument of each recorded call.
    observed_master_ips = set()
    for recorded_call in mock_send.call_args_list:
        observed_master_ips.add(recorded_call.args[1].master_dp_ip)
    assert observed_master_ips == {forced_master_ip}