import os
import json
import tempfile
import unittest
from mindie_llm.runtime.config.configuration_utils import LLMConfig
from mindie_llm.runtime.utils.helpers.parameter_validators import (
DictionaryParameterValidator, BooleanParameterValidator, RangeParamaterValidator
)
class TestConfiguratioinUtils(unittest.TestCase):
    """Tests for ``LLMConfig``: defaults, loading, ``update``, model merging and validation.

    Each test gets fresh JSON fixture files written into a private temporary
    directory by ``setUp``; ``tearDown`` removes that directory again.

    NOTE(review): "Configuratioin" is a typo in the class name; it is kept
    unchanged so existing test ids / selectors that reference it keep working.
    """

    @staticmethod
    def _write_json(path, payload):
        """Write ``payload`` to ``path`` as JSON, overwriting any existing file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f)

    def setUp(self):
        """Create the temporary config files and the validator tree used by the tests."""
        self.temp_dir = tempfile.TemporaryDirectory()

        # An empty JSON object; test_default_config loads it to check the
        # values LLMConfig falls back to when nothing is configured.
        self.empty_config_path = os.path.join(self.temp_dir.name, 'empty_config.json')
        self._write_json(self.empty_config_path, {})

        # Config 1: only an "llm" section, including an MTP plugin entry.
        # Booleans are written as strings ("false"/"true"); the assertions in
        # the tests below expect LLMConfig to expose them as real booleans.
        self.config_path1 = os.path.join(self.temp_dir.name, 'test_config1.json')
        self.test_json_config1 = {
            "llm": {
                "ccl": {
                    'backend': 'lccl',
                    "enable_mc2": "false"
                },
                "engine": {
                    "graph": "cpp"
                },
                "plugins": {
                    "plugin_type": "mtp",
                    "num_speculative_tokens": "1"
                },
                "enable_reasoning": "false"
            },
        }
        self._write_json(self.config_path1, self.test_json_config1)

        # Config 2: an "llm" section plus a per-model "deepseekv2" override
        # that test_merge_correct_model_config expects merge_models_config()
        # to apply on top of the "llm" values.
        self.config_path2 = os.path.join(self.temp_dir.name, 'test_config2.json')
        self.test_json_config2 = {
            "llm": {
                "ccl": {
                    'backend': 'lccl',
                    "enable_mc2": "false"
                },
                "engine": {
                    "graph": "cpp"
                },
                "enable_reasoning": "false"
            },
            "models": {
                "deepseekv2": {
                    "ccl": {
                        'backend': 'hccl',
                        "enable_mc2": "true"
                    },
                    "enable_reasoning": "true"
                }
            },
        }
        self._write_json(self.config_path2, self.test_json_config2)

        # Validator tree mirroring the "llm" section of config 2.
        # NOTE(review): "models" maps to a plain dict rather than a validator
        # object; presumably check_config treats it as "no constraints" — confirm.
        self.validators = {
            "llm": DictionaryParameterValidator({
                "ccl": DictionaryParameterValidator({
                    "backend": RangeParamaterValidator(range_list=["lccl", "hccl"]),
                    "enable_mc2": BooleanParameterValidator()
                }),
                "engine": DictionaryParameterValidator({
                    "graph": RangeParamaterValidator(range_list=["cpp", "python"])
                })
            }),
            "models": {}
        }

    def tearDown(self):
        """Delete the temporary directory together with every fixture file in it."""
        self.temp_dir.cleanup()

    def test_default_config(self):
        """An empty config file yields the default engine graph ('cpp')."""
        llm_config = LLMConfig(self.empty_config_path)
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')

    def test_load_config(self):
        """Values from the JSON file are exposed as nested attributes."""
        llm_config = LLMConfig(self.config_path1)
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.plugins.plugin_type, 'mtp')

    def test_update_dict_replace(self):
        """update() with allow_new_keys=False overwrites existing keys only."""
        llm_config = LLMConfig(self.config_path1)
        replace_dict = {
            "llm": {
                "engine": {
                    "graph": "python"
                },
                "enable_reasoning": "true"
            }
        }
        llm_config.update(replace_dict, allow_new_keys=False)
        self.assertIn("LLMConfig", repr(llm_config))
        # Replaced values.
        self.assertEqual(llm_config.llm.engine.graph, 'python')
        self.assertTrue(llm_config.llm.enable_reasoning)
        # Untouched values survive the update.
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')
        self.assertEqual(llm_config.llm.plugins.plugin_type, 'mtp')

    def test_update_dict_add_llm(self):
        """update() with allow_new_keys=True adds a new sub-section under "llm"."""
        llm_config = LLMConfig(self.config_path1)
        add_dict = {
            "llm": {
                "stream_options": {
                    "cv_dual": "false",
                    "micro_batch": "false"
                },
            }
        }
        llm_config.update(add_dict, allow_new_keys=True)
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')
        self.assertEqual(llm_config.llm.plugins.plugin_type, 'mtp')
        # "false" must be converted: a non-empty string would be truthy.
        self.assertFalse(llm_config.llm.stream_options.cv_dual)

    def test_update_dict_add_models(self):
        """update() with allow_new_keys=True adds a "models" section next to "llm"."""
        llm_config = LLMConfig(self.config_path1)
        add_models_dict = {
            "models": {
                "deepseekv2": {
                    "eplb": {
                        "level": 0,
                        "expert_map_file": "",
                        "num_of_redundant_experts": 0
                    },
                    "ep_level": 1,
                }
            }
        }
        llm_config.update(add_models_dict, allow_new_keys=True)
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.plugins.plugin_type, 'mtp')
        self.assertEqual(llm_config.models.deepseekv2.eplb.level, 0)
        # The empty string in the input is expected to come back as None.
        self.assertEqual(llm_config.models.deepseekv2.eplb.expert_map_file, None)
        self.assertEqual(llm_config.models.deepseekv2.eplb.num_of_redundant_experts, 0)
        self.assertEqual(llm_config.models.deepseekv2.ep_level, 1)

    def test_update_dict_add_none(self):
        """update(None, ...) leaves the already-loaded configuration untouched."""
        llm_config = LLMConfig(self.config_path1)
        add_models_dict = None
        llm_config.update(add_models_dict, allow_new_keys=True, current_path='models')
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.plugins.plugin_type, 'mtp')

    def test_merge_correct_model_config(self):
        """A known model name merges its overrides into the "llm" section."""
        llm_config = LLMConfig(self.config_path2)
        model_name = 'deepseekv2'
        llm_config.merge_models_config(model_name)
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.ccl.backend, 'hccl')
        self.assertTrue(llm_config.llm.enable_reasoning)

    def test_merge_incorrect_model_config(self):
        """An unknown model name leaves the "llm" section unchanged without raising."""
        llm_config = LLMConfig(self.config_path2)
        model_name = 'error_name'
        llm_config.merge_models_config(model_name)
        self.assertEqual(llm_config.llm.engine.graph, 'cpp')
        self.assertEqual(llm_config.llm.ccl.backend, 'lccl')

    def test_check_config(self):
        """A config whose values satisfy the validators passes check_config.

        The test passes as long as check_config() does not raise.
        """
        llm_config = LLMConfig(self.config_path2)
        llm_config.check_config(self.validators)
if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()