import os
import sys
import pytest
import signal
from unittest.mock import patch, MagicMock
# Set the config-related environment variables before the motor imports below run
# (presumably those modules read them at import time — verify).
# NOTE(review): .replace("\\", "/") is a no-op on this literal, which has no backslashes.
# NOTE(review): "useruser_config.json" looks like a possible typo — confirm the fixture file name.
os.environ["USER_CONFIG_PATH"] = "tests/jsons/useruser_config.json".replace("\\", "/")
os.environ["ROLE"] = "both"
# Put the directory two levels above this file first on sys.path
# (presumably so the `motor` and `tests` packages resolve — verify).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from motor.node_manager.core.engine_manager import EngineManager
from motor.node_manager.api_client.controller_api_client import ControllerApiClient
from motor.config.node_manager import NodeManagerConfig
from motor.common.resources.http_msg_spec import StartCmdMsg, RegisterMsg, ReregisterMsg
from motor.common.resources.endpoint import Endpoint
from motor.common.resources.instance import ParallelConfig, PDRole
from tests.node_manager.conftest import apply_node_manager_test_config, create_config_mock
@pytest.fixture(name="engine_manager")
def _engine_manager_fixture(config_data):
    """Yield an EngineManager built from a test config with file access and threads mocked.

    The manager is yielded from inside the ``with`` block, so ``safe_open``,
    ``threading.Thread`` and the environment overrides stay patched while the
    requesting test runs.
    """
    env_overrides = {"JOB_NAME": "test_job", "CONFIG_PATH": "tests/jsons", "ROLE": "both"}
    with (
        patch("motor.config.node_manager.safe_open") as safe_open_mock,
        patch("threading.Thread") as thread_cls_mock,
        patch.dict("os.environ", env_overrides),
    ):
        safe_open_mock.side_effect = create_config_mock(config_data)
        thread_cls_mock.return_value = MagicMock()

        # Drop EngineManager's entry from its instance cache, if there is one.
        has_cached_instance = hasattr(EngineManager, "_instances") and EngineManager in EngineManager._instances
        if has_cached_instance:
            del EngineManager._instances[EngineManager]

        node_config = NodeManagerConfig()
        apply_node_manager_test_config(node_config, config_data)
        manager = EngineManager(node_config)

        # Replace the register thread with a mock that reports itself as not running.
        manager._register_thread = MagicMock()
        manager._register_thread.is_alive.return_value = False
        yield manager
@pytest.fixture(name="sample_endpoints")
def _sample_endpoints_fixture():
    """Return two endpoints on 192.168.1.100 with ids 0 and 1 and distinct port pairs."""
    # (business_port, mgmt_port) for each endpoint; the list index becomes the endpoint id.
    port_pairs = (("8080", "9090"), ("8081", "9091"))
    return [
        Endpoint(id=index, ip="192.168.1.100", business_port=business, mgmt_port=mgmt)
        for index, (business, mgmt) in enumerate(port_pairs)
    ]
@pytest.fixture(name="sample_start_cmd_msg")
def _sample_start_cmd_msg_fixture(sample_endpoints):
    """Return a StartCmdMsg for job "test_job", instance 1, carrying the sample endpoints."""
    fields = {
        "job_name": "test_job",
        "role": "both",
        "instance_id": 1,
        "endpoints": sample_endpoints,
        "master_dp_ip": "192.168.1.100",
    }
    return StartCmdMsg(**fields)
class TestEngineManager:
    """Tests for EngineManager construction, registration messages and start-command parsing."""

    @staticmethod
    def _reset_singleton():
        """Remove EngineManager's entry from its ``_instances`` cache, if present."""
        if hasattr(EngineManager, "_instances") and EngineManager in EngineManager._instances:
            del EngineManager._instances[EngineManager]

    @staticmethod
    def _apply_base_config(engine_manager):
        """Set the identity/network config fields shared by the register and re-register tests."""
        config = engine_manager._config
        config.basic_config.job_name = "test_job"
        config.basic_config.role = PDRole.ROLE_U
        config.api_config.pod_ip = "192.168.1.100"
        config.api_config.host_ip = "192.168.1.200"
        config.api_config.node_manager_port = 8080
        config.basic_config.parallel_config = ParallelConfig(tp_size=2, pp_size=1)

    @classmethod
    def _apply_post_register_config(cls, engine_manager):
        """Set the base fields plus the model, service port and coordinator fields used by post_register tests."""
        cls._apply_base_config(engine_manager)
        config = engine_manager._config
        config.basic_config.model_name = "test_model"
        config.endpoint_config.service_ports = ["8080"]
        config.api_config.coordinator_api_dns = "localhost"
        config.api_config.coordinator_api_mgmt_port = 8080

    @staticmethod
    def _apply_cmd_config(engine_manager, job_name="test_job", endpoint_num=2, pod_ip="192.168.1.100"):
        """Set the three config fields the start-command tests vary; defaults match the sample message."""
        config = engine_manager._config
        config.basic_config.job_name = job_name
        config.endpoint_config.endpoint_num = endpoint_num
        config.api_config.pod_ip = pod_ip

    @patch("motor.config.node_manager.safe_open")
    @patch("threading.Thread")
    @patch.dict("os.environ", {"JOB_NAME": "test_job", "CONFIG_PATH": "./", "ROLE": "both"})
    def test_init_success(self, mock_thread_class, mock_safe_open, config_data):
        """Test EngineManager initialization"""
        mock_safe_open.side_effect = create_config_mock(config_data)
        mock_thread_class.return_value = MagicMock()
        self._reset_singleton()

        config = NodeManagerConfig()
        manager = EngineManager(config)

        assert manager.endpoints == []
        assert manager.instance_id == 0
        assert manager.is_working is False
        assert hasattr(manager, "_config")
        mock_thread_class.assert_called_once()

    @patch("motor.config.node_manager.safe_open")
    @patch("threading.Thread")
    @patch.dict("os.environ", {"JOB_NAME": "test_job", "CONFIG_PATH": "./", "ROLE": "both"})
    def test_singleton_pattern(self, mock_thread_class, mock_safe_open, config_data):
        """Test singleton pattern"""
        mock_safe_open.side_effect = create_config_mock(config_data)
        mock_thread_class.return_value = MagicMock()
        self._reset_singleton()

        config = NodeManagerConfig()
        manager1 = EngineManager(config)
        manager2 = EngineManager(config)

        assert manager1 is manager2

    def test_check_config_paras_success(self, engine_manager):
        """Test _check_config_paras with valid config"""
        engine_manager._config.basic_config.job_name = "test_job"
        assert engine_manager._check_config_paras() is True

    def test_check_config_paras_failure(self, engine_manager):
        """Test _check_config_paras with None job_name"""
        engine_manager._config.basic_config.job_name = None
        result = engine_manager._check_config_paras()
        # NOTE(review): this only verifies that a boolean-like value is returned without
        # raising; pin the expected outcome once the intended behavior is confirmed.
        assert result in [True, False]

    def test_gen_register_msg_success(self, engine_manager):
        """Test _gen_register_msg with valid config"""
        self._apply_base_config(engine_manager)
        engine_manager._config.basic_config.model_name = "test_model"
        engine_manager._config.endpoint_config.service_ports = ["8080", "8081"]
        engine_manager._config.basic_config.enable_multi_endpoints = True

        msg = engine_manager._gen_register_msg()

        # Report the unverified case as a skip instead of letting the test pass silently.
        if msg is None:
            pytest.skip("_gen_register_msg returned None for this configuration")
        assert isinstance(msg, RegisterMsg)
        assert msg.job_name == "test_job"
        assert msg.model_name == "test_model"
        assert msg.role == PDRole.ROLE_U
        assert msg.enable_multi_endpoints is True
        assert msg.is_master is False

    def test_gen_register_msg_includes_is_snapshot_master(self, engine_manager):
        """Test _gen_register_msg propagates is_snapshot_master as is_master."""
        # Kept explicit: unlike the shared base config, this test does not set host_ip.
        engine_manager._config.basic_config.job_name = "test_job"
        engine_manager._config.basic_config.model_name = "test_model"
        engine_manager._config.basic_config.role = PDRole.ROLE_U
        engine_manager._config.api_config.pod_ip = "192.168.1.100"
        engine_manager._config.endpoint_config.service_ports = ["8080"]
        engine_manager._config.endpoint_config.mgmt_ports = ["8081"]
        engine_manager._config.api_config.node_manager_port = 8080
        engine_manager._config.basic_config.parallel_config = ParallelConfig(tp_size=2, pp_size=1)
        engine_manager._config.basic_config.enable_multi_endpoints = True
        engine_manager._config.basic_config.device_num = 8
        engine_manager.is_snapshot_master = True

        msg = engine_manager._gen_register_msg()

        assert msg is not None
        assert msg.is_master is True

    def test_gen_register_msg_failure(self, engine_manager):
        """Test _gen_register_msg with invalid config"""
        engine_manager._config.basic_config.job_name = None
        msg = engine_manager._gen_register_msg()
        assert msg is None

    def test_gen_reregister_msg_success(self, engine_manager, sample_endpoints):
        """Test _gen_reregister_msg with valid data"""
        self._apply_base_config(engine_manager)
        engine_manager.endpoints = sample_endpoints
        engine_manager.instance_id = 1

        msg = engine_manager._gen_reregister_msg()

        assert msg is not None
        assert isinstance(msg, ReregisterMsg)
        assert msg.job_name == "test_job"
        assert msg.instance_id == 1
        assert msg.enable_multi_endpoints is True
        assert len(msg.endpoints) == 2

    def test_gen_reregister_msg_failure_no_endpoints(self, engine_manager):
        """Test _gen_reregister_msg with empty endpoints"""
        self._apply_base_config(engine_manager)
        engine_manager.endpoints = []
        engine_manager.instance_id = 1

        msg = engine_manager._gen_reregister_msg()

        assert msg is None

    def test_gen_reregister_msg_failure_no_instance_id(self, engine_manager, sample_endpoints):
        """Test _gen_reregister_msg with None instance_id"""
        self._apply_base_config(engine_manager)
        engine_manager.endpoints = sample_endpoints
        engine_manager.instance_id = None

        with pytest.raises(TypeError):
            engine_manager._gen_reregister_msg()

    @patch("motor.node_manager.core.engine_manager.ControllerApiClient.register")
    def test_post_register_msg_success(self, mock_register, engine_manager):
        """Test post_register_msg when ControllerApiClient.register returns True"""
        self._apply_post_register_config(engine_manager)
        mock_register.return_value = True

        result = engine_manager.post_register_msg()

        assert result is True
        mock_register.assert_called_once()

    @patch("motor.node_manager.core.engine_manager.ControllerApiClient.register")
    def test_post_register_msg_failure(self, mock_register, engine_manager):
        """Test post_register_msg when ControllerApiClient.register returns False"""
        self._apply_post_register_config(engine_manager)
        mock_register.return_value = False

        result = engine_manager.post_register_msg()

        assert result is False

    @patch("motor.node_manager.core.engine_manager.ControllerApiClient.re_register")
    def test_post_reregister_msg_success(self, mock_re_register, engine_manager, sample_endpoints):
        """Test post_reregister_msg when ControllerApiClient.re_register returns True"""
        self._apply_base_config(engine_manager)
        engine_manager.endpoints = sample_endpoints
        engine_manager.instance_id = 1
        mock_re_register.return_value = True

        result = engine_manager.post_reregister_msg()

        assert result is True
        mock_re_register.assert_called_once()

    @patch("motor.node_manager.core.engine_manager.ControllerApiClient.re_register")
    def test_post_reregister_msg_failure(self, mock_re_register, engine_manager, sample_endpoints):
        """Test post_reregister_msg when ControllerApiClient.re_register returns False"""
        self._apply_base_config(engine_manager)
        engine_manager.endpoints = sample_endpoints
        engine_manager.instance_id = 1
        mock_re_register.return_value = False

        result = engine_manager.post_reregister_msg()

        assert result is False

    def test_check_cmd_para_success(self, engine_manager, sample_start_cmd_msg):
        """Test _check_cmd_para with valid command"""
        self._apply_cmd_config(engine_manager)
        assert engine_manager._check_cmd_para(sample_start_cmd_msg) is True

    @pytest.mark.parametrize(
        "job_name,endpoint_num,pod_ip,expected",
        [
            ("wrong_job", 2, "192.168.1.100", False),
            ("test_job", 1, "192.168.1.100", False),
            ("test_job", 2, "192.168.1.101", False),
        ],
    )
    def test_check_cmd_para_failure(
        self, engine_manager, sample_start_cmd_msg, job_name, endpoint_num, pod_ip, expected
    ):
        """Test _check_cmd_para with invalid parameters"""
        # Each case makes exactly one config field disagree with the sample start command.
        self._apply_cmd_config(engine_manager, job_name=job_name, endpoint_num=endpoint_num, pod_ip=pod_ip)
        assert engine_manager._check_cmd_para(sample_start_cmd_msg) == expected

    def test_parse_start_cmd_success(self, engine_manager, sample_start_cmd_msg):
        """Test parse_start_cmd with valid command"""
        self._apply_cmd_config(engine_manager)

        result = engine_manager.parse_start_cmd(sample_start_cmd_msg)

        assert result is True
        assert engine_manager.instance_id == 1
        assert len(engine_manager.endpoints) == 2

    def test_stop(self, engine_manager):
        """Test stop method joins a register thread that is still alive"""
        mock_thread = MagicMock()
        mock_thread.is_alive.return_value = True
        engine_manager._register_thread = mock_thread

        engine_manager.stop()

        mock_thread.join.assert_called_once_with(timeout=2.0)

    @patch("motor.node_manager.core.engine_manager.wait_until_api_ready", return_value=True)
    @patch("motor.node_manager.core.engine_manager.time.sleep")
    @patch("motor.node_manager.core.engine_manager.EngineManager.post_register_msg")
    @patch("motor.node_manager.core.engine_manager.os.kill")
    def test_register_retry_mechanism(
        self, mock_kill, mock_post_register, mock_sleep, _mock_wait_api_ready, engine_manager
    ):
        """Test registration retry mechanism"""
        mock_sleep.return_value = None
        mock_post_register.return_value = False

        engine_manager._register()

        # Every attempt fails: five attempts are expected, then SIGTERM to this process.
        assert mock_post_register.call_count == 5
        mock_kill.assert_called_once_with(os.getpid(), signal.SIGTERM)

    @patch("motor.node_manager.core.engine_manager.wait_until_api_ready", return_value=True)
    @patch("motor.node_manager.core.engine_manager.EngineManager.post_register_msg")
    @patch("motor.node_manager.core.engine_manager.time.sleep")
    def test_register_success_on_first_attempt(
        self, mock_sleep, mock_post_register, _mock_wait_api_ready, engine_manager
    ):
        """Test registration succeeds on first attempt"""
        mock_post_register.return_value = True

        engine_manager._register()

        assert mock_post_register.call_count == 1
        mock_sleep.assert_not_called()

    @patch("motor.node_manager.core.engine_manager.wait_until_api_ready", return_value=True)
    @patch("motor.node_manager.core.engine_manager.EngineManager.post_register_msg")
    @patch("motor.node_manager.core.engine_manager.time.sleep")
    def test_register_success_on_retry(self, mock_sleep, mock_post_register, _mock_wait_api_ready, engine_manager):
        """Test registration succeeds on retry"""
        mock_post_register.side_effect = [False, True]

        engine_manager._register()

        assert mock_post_register.call_count == 2
        assert mock_sleep.call_count == 1
class TestD2DWeightTransfer:
    """Tests for D2D weight transfer peer IP handling in EngineManager."""

    @staticmethod
    def _match_sample_start_cmd(engine_manager):
        """Set the job name, endpoint count and pod IP to the values the start commands below carry."""
        config = engine_manager._config
        config.basic_config.job_name = "test_job"
        config.endpoint_config.endpoint_num = 2
        config.api_config.pod_ip = "192.168.1.100"

    def test_d2d_peer_ips_initialized_none(self, engine_manager):
        """A freshly built manager has d2d_peer_ips set to None."""
        assert engine_manager.d2d_peer_ips is None

    def test_parse_start_cmd_with_d2d_peer_ips(self, engine_manager, sample_start_cmd_msg):
        """parse_start_cmd copies a non-empty d2d_peer_ips list from the StartCmdMsg."""
        self._match_sample_start_cmd(engine_manager)
        sample_start_cmd_msg.d2d_peer_ips = ["10.0.0.1", "10.0.0.2"]

        parsed_ok = engine_manager.parse_start_cmd(sample_start_cmd_msg)

        assert parsed_ok is True
        assert engine_manager.d2d_peer_ips == ["10.0.0.1", "10.0.0.2"]

    def test_parse_start_cmd_with_empty_d2d_peer_ips(self, engine_manager, sample_start_cmd_msg):
        """parse_start_cmd keeps an empty d2d_peer_ips list as an empty list."""
        self._match_sample_start_cmd(engine_manager)
        sample_start_cmd_msg.d2d_peer_ips = []

        parsed_ok = engine_manager.parse_start_cmd(sample_start_cmd_msg)

        assert parsed_ok is True
        assert engine_manager.d2d_peer_ips == []

    def test_parse_start_cmd_with_default_d2d_peer_ips(self, engine_manager, sample_endpoints):
        """parse_start_cmd leaves d2d_peer_ips as None when the StartCmdMsg does not set it."""
        self._match_sample_start_cmd(engine_manager)
        start_cmd = StartCmdMsg(
            job_name="test_job",
            role="both",
            instance_id=1,
            endpoints=sample_endpoints,
            master_dp_ip="192.168.1.100",
        )

        parsed_ok = engine_manager.parse_start_cmd(start_cmd)

        assert parsed_ok is True
        assert engine_manager.d2d_peer_ips is None

    def test_start_delegates_to_fault_reporter(self, engine_manager):
        """start() calls start() on the fault reporter exactly once."""
        engine_manager._fault_reporter = MagicMock()

        engine_manager.start()

        engine_manager._fault_reporter.start.assert_called_once()

    def test_stop_delegates_to_fault_reporter(self, engine_manager):
        """stop() calls stop() on the fault reporter exactly once."""
        engine_manager._fault_reporter = MagicMock()

        engine_manager.stop()

        engine_manager._fault_reporter.stop.assert_called_once()

    def test_update_config_delegates_to_fault_reporter(self, engine_manager):
        """update_config() forwards the new config and the manager's endpoints to the fault reporter."""
        engine_manager._fault_reporter = MagicMock()
        replacement_config = NodeManagerConfig()

        engine_manager.update_config(replacement_config)

        engine_manager._fault_reporter.update_config.assert_called_once_with(
            replacement_config, engine_manager.endpoints
        )
class TestSnapshotSupport:
    """Tests for snapshot restore helpers added in EngineManager."""

    def test_get_snapshot_metadata_path_uses_custom_path(self, engine_manager):
        """A configured snapshot_metadata_path is returned unchanged."""
        engine_manager._config.snapshot_config.snapshot_metadata_path = "/custom/snapshot_metadata.json"
        assert engine_manager.get_snapshot_metadata_path() == "/custom/snapshot_metadata.json"

    def test_get_snapshot_metadata_path_returns_default(self, engine_manager):
        """An empty configured path falls back to MOTOR_SNAPSHOT_METADATA_PATH."""
        from motor.common.utils.snapshot_utils import MOTOR_SNAPSHOT_METADATA_PATH

        engine_manager._config.snapshot_config.snapshot_metadata_path = ""
        with patch("motor.node_manager.core.engine_manager.os.path.exists", return_value=False):
            assert engine_manager.get_snapshot_metadata_path() == MOTOR_SNAPSHOT_METADATA_PATH

    @patch("motor.node_manager.core.engine_manager.update_snapshot_metadata")
    @patch("motor.node_manager.core.engine_manager.load_snapshot_metadata")
    @patch("motor.node_manager.core.engine_manager.os.makedirs")
    def test_engine_suspend_prepare_initializes_metadata(
        self, mock_makedirs, mock_load, mock_update, engine_manager, tmp_path
    ):
        """engine_suspend_prepare creates the metadata file and records model_save_path."""
        from motor.common.utils.snapshot_utils import MOTOR_SNAPSHOT_WEIGHT_DIR

        metadata_path = str(tmp_path / "snapshot_metadata.json")
        engine_manager._config.snapshot_config.enable_snapshot = True
        engine_manager._config.snapshot_config.snapshot_metadata_path = ""
        # Every metadata lookup fails, as if the field had not been written yet.
        mock_load.side_effect = ValueError("missing field")

        with patch.object(engine_manager, "get_snapshot_metadata_path", return_value=metadata_path):
            engine_manager.engine_suspend_prepare()

        mock_makedirs.assert_called()
        mock_update.assert_called_once_with(metadata_path, "model_save_path", MOTOR_SNAPSHOT_WEIGHT_DIR)
        assert os.path.exists(metadata_path)

    def test_engine_suspend_prepare_skipped_when_snapshot_disabled(self, engine_manager):
        """engine_suspend_prepare creates no directories when snapshots are disabled."""
        engine_manager._config.snapshot_config.enable_snapshot = False
        with patch("motor.node_manager.core.engine_manager.os.makedirs") as mock_makedirs:
            engine_manager.engine_suspend_prepare()
        mock_makedirs.assert_not_called()

    @patch("motor.node_manager.core.engine_manager.get_pod_ip", return_value="10.1.2.3")
    @patch("motor.node_manager.core.engine_manager.load_snapshot_metadata")
    @patch("motor.node_manager.core.engine_manager.os.path.exists", return_value=True)
    def test_register_prepare_after_restore_refreshes_config(
        self, _mock_exists, mock_load, mock_get_pod_ip, engine_manager
    ):
        """register_prepare_after_restore refreshes job name, pod IP and the controller DNS namespace."""
        engine_manager._config.snapshot_config.enable_snapshot = True
        engine_manager._config.snapshot_config.snapshot_metadata_path = "/snapshot/snapshot_metadata.json"
        engine_manager._config.basic_config.job_name = "old-job"
        engine_manager._config.api_config.pod_ip = "10.0.0.1"

        mock_controller_config = MagicMock()
        mock_controller_config.api_config.controller_api_dns = "controller.old-ns.svc.cluster.local"

        # Serve the restored values per requested field name.
        mock_load.side_effect = lambda _path, field: {
            "job_name": "restored-job",
            "namespace": "new-ns",
        }[field]

        # Patch the class attribute so it is restored (or removed) after the test
        # instead of leaking the mock into every later test in the session.
        with patch.object(ControllerApiClient, "controller_config", mock_controller_config, create=True):
            engine_manager.register_prepare_after_restore()

            assert engine_manager._config.basic_config.job_name == "restored-job"
            assert engine_manager._config.api_config.pod_ip == "10.1.2.3"
            assert (
                ControllerApiClient.controller_config.api_config.controller_api_dns
                == "controller.new-ns.svc.cluster.local"
            )

    @patch("motor.node_manager.core.engine_manager.update_snapshot_metadata")
    @patch("motor.node_manager.core.engine_manager.load_snapshot_metadata")
    def test_engine_resume_prepare_updates_missing_fields(
        self, mock_load, mock_update, engine_manager, sample_start_cmd_msg, tmp_path
    ):
        """engine_resume_prepare writes model_load_path and data_parallel_master_ip when lookups fail."""
        from motor.common.utils.snapshot_utils import MOTOR_SNAPSHOT_WEIGHT_DIR

        metadata_path = str(tmp_path / "snapshot_metadata.json")
        engine_manager._config.snapshot_config.enable_snapshot = True
        engine_manager._config.snapshot_config.snapshot_metadata_path = ""
        mock_load.side_effect = ValueError("missing field")

        with patch.object(engine_manager, "get_snapshot_metadata_path", return_value=metadata_path):
            engine_manager.engine_resume_prepare(sample_start_cmd_msg)

        mock_update.assert_any_call(metadata_path, "model_load_path", MOTOR_SNAPSHOT_WEIGHT_DIR)
        mock_update.assert_any_call(
            metadata_path,
            "data_parallel_master_ip",
            sample_start_cmd_msg.master_dp_ip,
        )

    def test_is_engine_checkpoint_done_when_snapshot_disabled(self, engine_manager):
        """With snapshots disabled the checkpoint is reported as done."""
        engine_manager._config.snapshot_config.enable_snapshot = False
        assert engine_manager.is_engine_checkpoint_done() is True

    def test_is_engine_checkpoint_done_when_checkpoint_missing(self, engine_manager, tmp_path):
        """Metadata without a "checkpoint" entry is reported as not done."""
        engine_manager._config.snapshot_config.enable_snapshot = True
        metadata_path = tmp_path / "snapshot_metadata.json"
        metadata_path.write_text('{"model_save_path": "/snapshot/weight"}', encoding="utf-8")

        with patch.object(engine_manager, "get_snapshot_metadata_path", return_value=str(metadata_path)):
            assert engine_manager.is_engine_checkpoint_done() is False

    def test_is_engine_checkpoint_done_when_checkpoint_done(self, engine_manager, tmp_path):
        """Metadata with "checkpoint": "done" is reported as done."""
        engine_manager._config.snapshot_config.enable_snapshot = True
        metadata_path = tmp_path / "snapshot_metadata.json"
        metadata_path.write_text('{"checkpoint": "done"}', encoding="utf-8")

        with patch.object(engine_manager, "get_snapshot_metadata_path", return_value=str(metadata_path)):
            assert engine_manager.is_engine_checkpoint_done() is True