import os
import tempfile
import unittest
import yaml
from serving_cast.config import (
CommonConfig,
CommunicationConfig,
Config,
InstanceConfig,
LoadGenConfig,
ModelConfig,
ParallelConfig,
ServingConfig,
)
class TestParallelConfig(unittest.TestCase):
    """Checks that ParallelConfig exposes the values given to its constructor."""

    def test_custom_values(self):
        """Each explicitly supplied parallelism size is read back unchanged."""
        overrides = {
            "world_size": 8,
            "tp_size": 4,
            "dp_size": 2,
            "ep_size": 8,
            "mlp_tp_size": 2,
            "mlp_dp_size": 2,
        }
        parallel = ParallelConfig(**overrides)

        # The field name doubles as the failure message so a mismatch is
        # attributable to a specific attribute.
        for field, expected in overrides.items():
            self.assertEqual(getattr(parallel, field), expected, field)
class TestCommunicationConfig(unittest.TestCase):
    """Checks CommunicationConfig defaults and constructor overrides."""

    def _assert_fields(self, comm, expected_by_field):
        """Assert that every listed attribute of ``comm`` equals its expected value."""
        for field, expected in expected_by_field.items():
            self.assertEqual(getattr(comm, field), expected, field)

    def test_default_values(self):
        """A no-argument CommunicationConfig carries the expected defaults."""
        self._assert_fields(
            CommunicationConfig(),
            {
                "host2device_bandwidth": 1e10,
                "host2device_rate": 0.5,
                "device2device_bandwidth": 4e9,
                "device2device_rate": 0.5,
            },
        )

    def test_custom_values(self):
        """Values passed to the constructor replace the defaults."""
        overrides = {
            "host2device_bandwidth": 1e9,
            "host2device_rate": 0.8,
            "device2device_bandwidth": 1e10,
            "device2device_rate": 0.9,
        }
        self._assert_fields(CommunicationConfig(**overrides), overrides)
class TestInstanceConfig(unittest.TestCase):
    """Checks construction of InstanceConfig from explicit arguments."""

    def test_instance_config_creation(self):
        """Constructor arguments are stored, and device_type has its own value.

        ``device_type`` is not passed to the constructor here, so the final
        expectation checks whatever InstanceConfig supplies for it by itself.
        """
        parallel = ParallelConfig()
        communication = CommunicationConfig()
        instance = InstanceConfig(
            num_instances=4,
            num_devices_per_instance=8,
            pd_role="prefill",
            parallel_config=parallel,
            communication_config=communication,
        )

        expectations = (
            ("num_instances", 4),
            ("num_devices_per_instance", 8),
            ("pd_role", "prefill"),
            ("device_type", "TEST_DEVICE"),
        )
        for field, expected in expectations:
            self.assertEqual(getattr(instance, field), expected, field)
class TestLoadGenConfig(unittest.TestCase):
    """Checks that LoadGenConfig stores its constructor arguments."""

    def test_load_gen_config_creation(self):
        """All load-generation settings are read back as supplied."""
        settings = {
            "load_gen_type": "fixed_length",
            "num_requests": 100,
            "num_input_tokens": 1000,
            "num_output_tokens": 100,
            "request_rate": 1.0,
        }
        load_gen = LoadGenConfig(**settings)

        for field, expected in settings.items():
            self.assertEqual(getattr(load_gen, field), expected, field)
class TestServingConfig(unittest.TestCase):
    """Checks ServingConfig defaults and constructor overrides."""

    def _check(self, serving, max_concurrency, block_size, max_tokens_budget):
        """Compare the three ServingConfig fields against the given values."""
        observed = (
            ("max_concurrency", max_concurrency),
            ("block_size", block_size),
            ("max_tokens_budget", max_tokens_budget),
        )
        for field, expected in observed:
            self.assertEqual(getattr(serving, field), expected, field)

    def test_default_values(self):
        """A no-argument ServingConfig carries the expected defaults."""
        self._check(
            ServingConfig(),
            max_concurrency=100,
            block_size=128,
            max_tokens_budget=8192,
        )

    def test_custom_values(self):
        """Values passed to the constructor replace the defaults."""
        custom = ServingConfig(
            max_concurrency=200,
            block_size=256,
            max_tokens_budget=16384,
        )
        self._check(
            custom,
            max_concurrency=200,
            block_size=256,
            max_tokens_budget=16384,
        )
class TestModelConfig(unittest.TestCase):
    """Checks ModelConfig defaults and constructor overrides."""

    def test_model_config_creation(self):
        """With only ``name`` given, the remaining fields hold their defaults."""
        model = ModelConfig(name="test-model")

        expected_values = {
            "name": "test-model",
            "num_mtp_tokens": 0,
            "quantize_linear_action": "W8A8_DYNAMIC",
            "mxfp4_group_size": 32,
            "quantize_attention_action": "DISABLED",
        }
        for field, expected in expected_values.items():
            self.assertEqual(getattr(model, field), expected, field)

        # These are checked for falsiness rather than equality with False.
        disabled_flags = (
            "do_compile",
            "allow_graph_break",
            "dump_input_shapes",
            "quantize_lmhead",
        )
        for flag in disabled_flags:
            self.assertFalse(getattr(model, flag), flag)

    def test_model_config_custom_values(self):
        """Every value passed to the constructor is read back from the config."""
        overrides = {
            "name": "custom-model",
            "num_mtp_tokens": 4,
            "do_compile": True,
            "allow_graph_break": True,
            "quantize_linear_action": "FP8",
            "quantize_lmhead": True,
            "enable_multi_process": True,
            "num_processes": 8,
            "predict_steps": 10,
            "enable_interpolate": False,
            "interpolation_seed": 42,
        }
        model = ModelConfig(**overrides)

        for field, supplied in overrides.items():
            actual = getattr(model, field)
            # Boolean options are checked by truthiness; everything else by
            # equality with the supplied value.
            if supplied is True:
                self.assertTrue(actual, field)
            elif supplied is False:
                self.assertFalse(actual, field)
            else:
                self.assertEqual(actual, supplied, field)
class TestCommonConfig(unittest.TestCase):
    """Checks that CommonConfig holds the sub-configs it was built from."""

    def test_common_config_creation(self):
        """The model, load-gen and serving configs are read back as supplied."""
        parts = {
            "model_config": ModelConfig(name="test-model"),
            "load_gen": LoadGenConfig(
                load_gen_type="fixed_length",
                num_requests=100,
                num_input_tokens=1000,
                num_output_tokens=100,
                request_rate=1.0,
            ),
            "serving_config": ServingConfig(),
        }
        common = CommonConfig(**parts)

        for field, supplied in parts.items():
            self.assertEqual(getattr(common, field), supplied, field)
class TestConfig(unittest.TestCase):
    """Tests for the Config singleton: construction and ``get_instance()``."""

    @staticmethod
    def _reset_singleton():
        """Clear the class-level singleton state kept on Config."""
        Config._instance = None
        Config._initialized = False

    def setUp(self):
        """Reset Config singleton before each test and again once it finishes."""
        self._reset_singleton()
        # Cleanups run even when the test fails, so a Config built by a test
        # in this class never leaks into tests that run later in the process.
        self.addCleanup(self._reset_singleton)

    def _write_config_files(self):
        """Write the instance and common YAML fixtures into a temp directory.

        The directory is removed by a registered cleanup when the test ends.

        Returns:
            A ``(instance_path, common_path)`` tuple of the written files.
        """
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)

        instance_config = {
            "instance_groups": [
                {
                    "num_instances": 1,
                    "num_devices_per_instance": 4,
                    "pd_role": "both",
                }
            ]
        }
        instance_path = os.path.join(tmpdir.name, "instance.yaml")
        with open(instance_path, "w", encoding="utf-8") as f:
            yaml.dump(instance_config, f)

        common_config = {
            "model_config": {"name": "test-model"},
            "load_gen": {
                "load_gen_type": "fixed_length",
                "num_requests": 10,
                "num_input_tokens": 100,
                "num_output_tokens": 10,
                "request_rate": 1.0,
            },
        }
        common_path = os.path.join(tmpdir.name, "common.yaml")
        with open(common_path, "w", encoding="utf-8") as f:
            yaml.dump(common_config, f)

        return instance_path, common_path

    @staticmethod
    def _parsed_args(instance_path, common_path):
        """Return a fresh stand-in args object pointing at the given files.

        The object carries the three attributes these tests hand to Config:
        ``instance_config_path``, ``common_config_path`` and
        ``enable_profiling`` (always False).
        """

        class ParsedArgs:
            instance_config_path = instance_path
            common_config_path = common_path
            enable_profiling = False

        return ParsedArgs()

    def test_config_get_instance_not_initialized(self):
        """Test that get_instance raises error when not initialized."""
        with self.assertRaises(ValueError):
            Config.get_instance()

    def test_config_singleton(self):
        """Test that Config is a singleton."""
        paths = self._write_config_files()
        # Two distinct args objects: the second construction must still hand
        # back the object created by the first.
        config1 = Config(self._parsed_args(*paths))
        config2 = Config(self._parsed_args(*paths))
        self.assertIs(config1, config2)

    def test_config_get_instance_after_init(self):
        """Test get_instance returns config after initialization."""
        paths = self._write_config_files()
        config = Config(self._parsed_args(*paths))
        self.assertEqual(Config.get_instance(), config)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()