import unittest
from unittest.mock import patch
import numpy as np
from mindie_llm.model_wrapper.utils.config import BaseConfig, DmiConfig, LinkMapParams
class TestBaseConfig(unittest.TestCase):
    """Tests for ``BaseConfig``: typed parsing, required-field validation and list parsing."""

    def setUp(self):
        # Every value is supplied as a string; the assertions below check that
        # BaseConfig converts them to ints / bools where expected.
        self.base_config = dict(
            local_rank="0",
            rank="0",
            world_size="2",
            npu_device_id="0",
            npu_device_ids="0,1",
            cpu_mem="1024",
            npu_mem="2048",
            max_seq_len="512",
            max_iter_times="100",
            max_prefill_tokens="1024",
            block_size="32",
            model_id="./model",
            distributed_enable="true",
            globalWorldSize="2",
        )

    def test_base_config_init(self):
        parsed = BaseConfig(self.base_config)
        # "globalWorldSize" / "world_size" are strings in the input but ints after parsing.
        self.assertIsNotNone(parsed.global_world_size)
        self.assertEqual(parsed.global_world_size, 2)
        self.assertEqual(parsed.world_size, 2)
        # The string "true" must become a truthy flag.
        self.assertTrue(parsed.distributed_enable)

    def test_missing_required_field(self):
        # Same config with "model_id" dropped; construction must fail with ValueError.
        incomplete = {key: value for key, value in self.base_config.items() if key != "model_id"}
        with self.assertRaises(ValueError):
            BaseConfig(incomplete)

    def test_parse_list(self):
        expected_ids = [0, 1]
        parsed = BaseConfig(self.base_config)
        # "0,1" split on commas and converted element-wise to int.
        self.assertEqual(parsed.parse_list("npu_device_ids", to_int=True), expected_ids)
class TestDmiConfig(unittest.TestCase):
    """Tests for ``DmiConfig``: construction, role handling, link-map generation and PD link info."""

    def setUp(self):
        # Base config fields plus the DMI-specific ones (role, per-device id/ip lists,
        # tp, local instance id). All values are strings, as DmiConfig receives them.
        self.dmi_config = {
            "local_rank": "0",
            "rank": "0",
            "world_size": "2",
            "npu_device_id": "0",
            "npu_device_ids": "0,1",
            "cpu_mem": "1024",
            "npu_mem": "2048",
            "max_seq_len": "512",
            "max_iter_times": "100",
            "max_prefill_tokens": "1024",
            "block_size": "32",
            "model_id": "./model",
            "distributed_enable": "true",
            "globalWorldSize": "2",
            "role": "prefill",
            "local_logic_device_id": "0,1",
            "local_physical_device_id": "10,11",
            "local_device_ip": "192.168.0.1,192.168.0.2",
            "local_host_ip": "10.0.0.1,10.0.0.2",
            "tp": "1",
            "local_instance_id": "1",
        }

    def test_dmi_config_init_success(self):
        config = DmiConfig(self.dmi_config)
        self.assertIsNotNone(config.global_world_size)
        self.assertEqual(config.global_world_size, 2)
        # Checked for None first so a missing value fails cleanly instead of
        # erroring inside int().
        self.assertIsNotNone(config.model_config["world_size"])
        self.assertEqual(int(config.model_config["world_size"]), 2)
        self.assertEqual(config.role, "prefill")

    def test_invalid_role_raises_error(self):
        invalid_config = {**self.dmi_config, "role": "invalid_role"}
        with self.assertRaises(ValueError) as ctx:
            DmiConfig(invalid_config)
        self.assertIn("The pd_role should be prefill or decoder in DMI mode.", str(ctx.exception))

    def test_generate_link_map_matches_actual_logic(self):
        params = LinkMapParams(role="prefill", tp_p=4, tp_d=2)
        prefill_map = DmiConfig.generate_link_map(params)
        self.assertEqual(prefill_map, {0: [0], 2: [1]})

        params = LinkMapParams(role="decoder", tp_p=4, tp_d=2)
        decoder_map = DmiConfig.generate_link_map(params)
        self.assertEqual(decoder_map, {0: [0], 1: [2]})

        # A zero TP size on either side must raise ValueError. NOTE: LinkMapParams
        # is constructed inside the assertRaises block, so a ValueError raised by
        # its own constructor also satisfies this check.
        with self.assertRaises(ValueError):
            params = LinkMapParams(role="prefill", tp_p=0, tp_d=2)
            DmiConfig.generate_link_map(params)
        with self.assertRaises(ValueError):
            params = LinkMapParams(role="decoder", tp_p=4, tp_d=0)
            DmiConfig.generate_link_map(params)

    def test_generate_link_map_for_sp(self):
        params = LinkMapParams(role="prefill", tp_p=8, tp_d=1, sp_p=8, sp_d=1, cp_p=2, cp_d=1)
        prefill_sp_map = DmiConfig.generate_link_map(params)
        # Prefill ranks 0..15 each map to [0].
        self.assertEqual(prefill_sp_map, {rank: [0] for rank in range(16)})

        params = LinkMapParams(role="decoder", tp_p=8, tp_d=1, sp_p=8, sp_d=1, cp_p=2, cp_d=1)
        decoder_sp_map = DmiConfig.generate_link_map(params)
        # The single decoder rank 0 maps to all prefill ranks 0..15.
        self.assertEqual(decoder_sp_map, {0: list(range(16))})

    def test_clear_remote_info(self):
        config = DmiConfig(self.dmi_config)
        config.remote_sp_size = 5
        config.remote_link_host_ip = {"node1": "192.168.0.100"}
        config.clear_remote_info()
        self.assertEqual(config.remote_sp_size, 0)
        self.assertEqual(config.remote_link_host_ip, {})

    def test_set_pd_role(self):
        config = DmiConfig(self.dmi_config)
        # (role code, expected role string); 4 is outside the known codes.
        cases = ((1, "prefill"), (2, "decoder"), (3, "flex"), (4, "unknown"))
        for role_code, expected_role in cases:
            config.set_pd_role(role_code)
            self.assertEqual(config.role, expected_role)

    @patch("mindie_llm.model_wrapper.utils.config.logger")
    def test_set_pd_link_info(self, mock_logger):
        config = DmiConfig(self.dmi_config)
        # NOTE(review): the layout of these three arrays is defined by
        # set_pd_link_info; the values are fixtures, not derived here.
        mock_requests = [
            np.array([[1, 0, 0, 0, 0, 0, 0]], dtype=np.int64),
            np.array([[[0, 0, 0, 0, 0, 0, 0]]], dtype=np.int64),
            np.array([[0, 1, 1]], dtype=np.int64),
        ]
        config.set_pd_link_info(mock_requests)
        mock_logger.info.assert_any_call("[Config]\t>>> start to set PD link/unlink info according to the request.")
        self.assertEqual(config.remote_sp_size, 1)

    def test_dp_size_logic(self):
        dp_config = {**self.dmi_config, "dp": "2"}
        with patch("mindie_llm.model_wrapper.utils.config.generate_dp_inst_id",
                   return_value=["10", "11"]):
            config = DmiConfig(dp_config)
        # The first generated DP instance id becomes the local instance id.
        self.assertEqual(config.model_config["local_instance_id"], "10")

    def test_set_pd_link_info_with_unlink(self):
        full_config = {
            **self.dmi_config,
            "sp": "1",
            "local_super_device_id": "20,21",
            "dp": "2",
            "local_instance_id": "1000",
        }
        mock_requests = [
            np.array([[1, 0, 1, 1, 1, 0, 0]], dtype=np.int64),
            np.array([
                [[192, 168, 0, 100, -1, -1, -1, -1, 1000], [192, 168, 0, 101, -1, -1, -1, -1, 1001]],
                [[192, 168, 0, 200, -1, -1, -1, -1, 2000], [192, 168, 0, 201, -1, -1, -1, -1, 2001]],
            ], dtype=np.int64),
            np.array([[0, 1, 1]], dtype=np.int64),
        ]
        # The logger is mocked for both construction and the call under test, as in
        # test_set_pd_link_info. generate_dp_inst_id is patched only for
        # construction, so set_pd_link_info runs against the real helper.
        with patch("mindie_llm.model_wrapper.utils.config.logger"):
            with patch("mindie_llm.model_wrapper.utils.config.generate_dp_inst_id",
                       return_value=["10000", "10001"]):
                config = DmiConfig(full_config)
            config.set_pd_link_info(mock_requests)

        self.assertIn(0, config.remote_unlink_cluster_id)
        self.assertIn(2000, config.remote_unlink_cluster_id[0])
        self.assertIn(0, config.remote_unlink_device_ips)
        self.assertEqual(
            config.remote_unlink_device_ips[0],
            ["192.168.0.201"],
        )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()