import os
import sys
import tempfile
import unittest
from serving_cast.config import LoadGenConfig
from serving_cast.main import get_load_gen, get_serving, instance_group2pd_type, parse_command_line_args
class TestParseCommandLineArgs(unittest.TestCase):
"""Tests for parse_command_line_args function."""
def setUp(self):
"""Set up test fixtures with temporary config files."""
self.temp_dir = tempfile.mkdtemp()
self.instance_config_path = os.path.join(self.temp_dir, "instance.yaml")
self.common_config_path = os.path.join(self.temp_dir, "common.yaml")
with open(self.instance_config_path, "w", encoding="utf-8") as f:
f.write("instance_groups: []")
with open(self.common_config_path, "w", encoding="utf-8") as f:
f.write(
"model_config: {name: test}\nload_gen: "
"{load_gen_type: fixed_length, num_requests: 100, num_input_tokens: 128, "
"num_output_tokens: 128, request_rate: 1.0}\n"
"serving_config: {}"
)
def tearDown(self):
"""Clean up temporary files."""
import shutil
shutil.rmtree(self.temp_dir)
def test_parse_args_valid_required_args(self):
"""Test parsing with valid required arguments."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
self.common_config_path,
]
args = parse_command_line_args()
self.assertEqual(args.instance_config_path, self.instance_config_path)
self.assertEqual(args.common_config_path, self.common_config_path)
self.assertFalse(args.enable_profiling)
self.assertEqual(args.profiling_output_path, "./profiling_results")
finally:
sys.argv = original_argv
def test_parse_args_with_profiling_enabled(self):
"""Test parsing with profiling enabled."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
self.common_config_path,
"--enable_profiling",
]
args = parse_command_line_args()
self.assertTrue(args.enable_profiling)
finally:
sys.argv = original_argv
def test_parse_args_with_custom_profiling_output_path(self):
"""Test parsing with custom profiling output path."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
self.common_config_path,
"--profiling_output_path",
"/custom/path",
]
args = parse_command_line_args()
self.assertEqual(args.profiling_output_path, "/custom/path")
finally:
sys.argv = original_argv
def test_parse_args_output_json_default_none(self):
"""Test that --output_json defaults to None when not provided."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
self.common_config_path,
]
args = parse_command_line_args()
self.assertIsNone(args.output_json)
finally:
sys.argv = original_argv
def test_parse_args_with_output_json(self):
"""Test parsing with --output_json provided."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
self.common_config_path,
"--output_json",
"/tmp/summary.json",
]
args = parse_command_line_args()
self.assertEqual(args.output_json, "/tmp/summary.json")
finally:
sys.argv = original_argv
def test_parse_args_missing_instance_config(self):
"""Test parsing with missing instance config path."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--common_config_path",
self.common_config_path,
]
with self.assertRaises(SystemExit):
parse_command_line_args()
finally:
sys.argv = original_argv
def test_parse_args_missing_common_config(self):
"""Test parsing with missing common config path."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
]
with self.assertRaises(SystemExit):
parse_command_line_args()
finally:
sys.argv = original_argv
def test_parse_args_nonexistent_instance_config(self):
"""Test parsing with non-existent instance config file."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
"/nonexistent/path.yaml",
"--common_config_path",
self.common_config_path,
]
with self.assertRaises(SystemExit):
parse_command_line_args()
finally:
sys.argv = original_argv
def test_parse_args_nonexistent_common_config(self):
"""Test parsing with non-existent common config file."""
original_argv = sys.argv
try:
sys.argv = [
"main.py",
"--instance_config_path",
self.instance_config_path,
"--common_config_path",
"/nonexistent/path.yaml",
]
with self.assertRaises(SystemExit):
parse_command_line_args()
finally:
sys.argv = original_argv
class TestInstanceGroup2PdType(unittest.TestCase):
"""Tests for instance_group2pd_type function."""
def test_pd_aggregation(self):
"""Test pd_aggregation detection."""
instance_group = {"both": [1, 2], "prefill": [], "decode": []}
result = instance_group2pd_type(instance_group)
self.assertEqual(result, "pd_aggregation")
def test_pd_disaggregation(self):
"""Test pd_disaggregation detection."""
instance_group = {"both": [], "prefill": [1], "decode": [2]}
result = instance_group2pd_type(instance_group)
self.assertEqual(result, "pd_disaggregation")
def test_invalid_both_empty(self):
"""Test invalid when all empty."""
instance_group = {"both": [], "prefill": [], "decode": []}
result = instance_group2pd_type(instance_group)
self.assertIsNone(result)
def test_invalid_both_non_empty(self):
"""Test invalid when both both and prefill non-empty."""
instance_group = {"both": [1], "prefill": [2], "decode": []}
result = instance_group2pd_type(instance_group)
self.assertIsNone(result)
class TestGetServing(unittest.TestCase):
"""Tests for get_serving function."""
def test_get_serving_invalid(self):
"""Test get_serving raises error for invalid config."""
instance_group = {"both": [], "prefill": [], "decode": []}
with self.assertRaises(ValueError) as context:
get_serving(instance_group)
self.assertIn("Unknown pd type", str(context.exception))
class TestGetLoadGen(unittest.TestCase):
"""Tests for get_load_gen function."""
def test_get_load_gen_fixed_length(self):
"""Test get_load_gen for fixed_length type."""
config = LoadGenConfig(
load_gen_type="fixed_length",
num_requests=100,
num_input_tokens=128,
num_output_tokens=128,
request_rate=1.0,
)
result = get_load_gen(config)
self.assertIsNotNone(result)
self.assertEqual(result.num_requests, 100)
self.assertEqual(result.request_rate, 1.0)
def test_get_load_gen_unknown_type(self):
"""Test get_load_gen raises error for unknown type."""
config = LoadGenConfig(
load_gen_type="fixed_length",
num_requests=100,
num_input_tokens=128,
num_output_tokens=128,
request_rate=1.0,
)
config.load_gen_type = "unknown"
with self.assertRaises((ValueError, AttributeError)):
get_load_gen(config)
if __name__ == "__main__":
unittest.main()