import unittest
from unittest.mock import patch
from mindie_llm.runtime.config.load_config import LoadConfig
class TestLoadConfig(unittest.TestCase):
def test_post_init_tokenizer_path_default(self):
config = LoadConfig(model_name_or_path="/path/to/model")
self.assertEqual(config.tokenizer_path, "/path/to/model")
def test_post_init_tokenizer_path_custom(self):
config = LoadConfig(model_name_or_path="/path/to/model", tokenizer_path="/custom/tokenizer")
self.assertEqual(config.tokenizer_path, "/custom/tokenizer")
def test_from_dict_creates_instance(self):
config_dict = {
"model_name_or_path": "/path/to/model",
"max_position_embeddings": 4096,
"trust_remote_code": True,
"load_tokenizer": False,
"tokenizer_path": "/custom/tokenizer",
"llm_config_path": "/config.json",
"models_dict": {"key": "value"}
}
config = LoadConfig.from_dict(config_dict)
self.assertEqual(config.model_name_or_path, "/path/to/model")
self.assertEqual(config.max_position_embeddings, 4096)
self.assertTrue(config.trust_remote_code)
self.assertFalse(config.load_tokenizer)
self.assertEqual(config.tokenizer_path, "/custom/tokenizer")
self.assertEqual(config.llm_config_path, "/config.json")
self.assertEqual(config.models_dict, {"key": "value"})
def test_from_dict_ignores_invalid_fields(self):
config_dict = {
"model_name_or_path": "/path/to/model",
"invalid_field": "should_be_ignored",
"max_position_embeddings": 4096
}
config = LoadConfig.from_dict(config_dict)
self.assertEqual(config.model_name_or_path, "/path/to/model")
self.assertEqual(config.max_position_embeddings, 4096)
self.assertFalse(hasattr(config, "invalid_field"))
def test_validate_max_position_embeddings_valid(self):
config = LoadConfig(model_name_or_path="model", max_position_embeddings=4096)
config.validate()
def test_validate_max_position_embeddings_invalid(self):
config = LoadConfig(model_name_or_path="model", max_position_embeddings=-1)
with self.assertRaises(ValueError) as cm:
config.validate()
self.assertIn("max_position_embeddings must be a positive integer", str(cm.exception))
def test_validate_max_position_embeddings_zero(self):
config = LoadConfig(model_name_or_path="model", max_position_embeddings=0)
with self.assertRaises(ValueError) as cm:
config.validate()
self.assertIn("max_position_embeddings must be a positive integer", str(cm.exception))
def test_validate_models_dict_valid(self):
config = LoadConfig(model_name_or_path="model", models_dict={str(i): i for i in range(4096)})
config.validate()
def test_validate_models_dict_invalid_too_long(self):
config = LoadConfig(model_name_or_path="model", models_dict={str(i): i for i in range(4097)})
with self.assertRaises(ValueError) as cm:
config.validate()
self.assertIn("The length of plugin_params (4097) is too long", str(cm.exception))
def test_validate_models_dict_none(self):
config = LoadConfig(model_name_or_path="model", models_dict=None)
config.validate()
def test_validate_models_dict_empty(self):
config = LoadConfig(model_name_or_path="model", models_dict={})
config.validate()
if __name__ == '__main__':
unittest.main()