import json
import os
import tempfile
import unittest
from dataclasses import dataclass
from typing import Optional
from serving_cast.request import Request, RequestState
from serving_cast.utils import dataclass2dict, gen_profiling_config_set_env_variable, get_basic_timestamp, summarize
@dataclass
class NestedDataclass:
value: int
@dataclass
class SampleDataclass:
name: str
count: int
nested: Optional[NestedDataclass] = None
@dataclass
class ComplexDataclass:
items: list[int]
nested_list: list[NestedDataclass]
dict_field: dict
class TestDataclass2Dict(unittest.TestCase):
def test_simple_dataclass(self):
"""Test converting a simple dataclass to dict."""
obj = SampleDataclass(name="test", count=42)
result = dataclass2dict(obj)
self.assertEqual(result["name"], "test")
self.assertEqual(result["count"], 42)
self.assertIsNone(result["nested"])
def test_nested_dataclass(self):
"""Test converting a dataclass with nested dataclass."""
nested = NestedDataclass(value=100)
obj = SampleDataclass(name="test", count=42, nested=nested)
result = dataclass2dict(obj)
self.assertEqual(result["name"], "test")
self.assertEqual(result["count"], 42)
self.assertEqual(result["nested"]["value"], 100)
def test_skip_none_true(self):
"""Test skip_none=True removes None fields."""
obj = SampleDataclass(name="test", count=42, nested=None)
result = dataclass2dict(obj, skip_none=True)
self.assertEqual(result["name"], "test")
self.assertEqual(result["count"], 42)
self.assertNotIn("nested", result)
def test_skip_none_false(self):
"""Test skip_none=False keeps None fields."""
obj = SampleDataclass(name="test", count=42, nested=None)
result = dataclass2dict(obj, skip_none=False)
self.assertEqual(result["name"], "test")
self.assertEqual(result["count"], 42)
self.assertIn("nested", result)
self.assertIsNone(result["nested"])
def test_list_of_dataclasses(self):
"""Test converting dataclass with list of dataclasses."""
obj = ComplexDataclass(
items=[1, 2, 3],
nested_list=[NestedDataclass(value=i) for i in range(3)],
dict_field={"key": "value"},
)
result = dataclass2dict(obj)
self.assertEqual(result["items"], [1, 2, 3])
self.assertEqual(len(result["nested_list"]), 3)
for i, item in enumerate(result["nested_list"]):
self.assertEqual(item["value"], i)
self.assertEqual(result["dict_field"], {"key": "value"})
def test_non_dataclass_raises_error(self):
"""Test that non-dataclass raises TypeError."""
with self.assertRaises(TypeError):
dataclass2dict({"not": "a dataclass"})
with self.assertRaises(TypeError):
dataclass2dict([1, 2, 3])
with self.assertRaises(TypeError):
dataclass2dict("string")
class TestGetBasicTimestamp(unittest.TestCase):
def test_timestamp_format(self):
"""Test that timestamp has correct format."""
timestamp = get_basic_timestamp()
parts = timestamp.split("_")
self.assertEqual(len(parts), 2)
date_part, time_part = parts
self.assertEqual(len(date_part.split("-")), 3)
self.assertEqual(len(time_part.split("-")), 3)
def test_timestamp_is_string(self):
"""Test that timestamp is a string."""
timestamp = get_basic_timestamp()
self.assertIsInstance(timestamp, str)
def test_timestamp_not_empty(self):
"""Test that timestamp is not empty."""
timestamp = get_basic_timestamp()
self.assertTrue(len(timestamp) > 0)
class TestGenProfilingConfigSetEnvVariable(unittest.TestCase):
def test_creates_config_file(self):
"""Test that config file is created."""
with tempfile.TemporaryDirectory() as tmpdir:
gen_profiling_config_set_env_variable(tmpdir)
config_path = os.path.join(tmpdir, "profiling_config.json")
self.assertTrue(os.path.exists(config_path))
def test_config_content(self):
"""Test that config file has correct content."""
with tempfile.TemporaryDirectory() as tmpdir:
gen_profiling_config_set_env_variable(tmpdir)
config_path = os.path.join(tmpdir, "profiling_config.json")
with open(config_path, encoding="utf-8") as f:
config = json.load(f)
self.assertEqual(config["enable"], 1)
self.assertEqual(config["prof_dir"], tmpdir)
self.assertEqual(config["profiler_level"], "INFO")
def test_env_variable_set(self):
"""Test that environment variable is set."""
with tempfile.TemporaryDirectory() as tmpdir:
gen_profiling_config_set_env_variable(tmpdir)
config_path = os.path.join(tmpdir, "profiling_config.json")
self.assertEqual(os.environ.get("SERVICE_PROF_CONFIG_PATH"), config_path)
class TestSummarize(unittest.TestCase):
def test_summarize_basic(self):
"""Test summarize with basic request data."""
request1 = Request(num_input_tokens=100, num_output_tokens=50)
request1.leaves_client_time = 0.0
request1.arrives_server_time = 0.1
request1.prefill_done_time = 1.0
request1.decode_done_time = 5.0
request1._state = RequestState.DECODE_DONE
request2 = Request(num_input_tokens=200, num_output_tokens=100)
request2.leaves_client_time = 0.5
request2.arrives_server_time = 0.6
request2.prefill_done_time = 2.0
request2.decode_done_time = 10.0
request2._state = RequestState.DECODE_DONE
summarize([request1, request2])
def test_summarize_single_request(self):
"""Test summarize with single request."""
request = Request(num_input_tokens=100, num_output_tokens=10)
request.leaves_client_time = 0.0
request.arrives_server_time = 0.0
request.prefill_done_time = 1.0
request.decode_done_time = 2.0
request._state = RequestState.DECODE_DONE
summarize([request])
def test_summarize_writes_output_json(self):
"""summarize() writes structured per-metric and overall summaries to JSON."""
request1 = Request(num_input_tokens=100, num_output_tokens=50)
request1.leaves_client_time = 0.0
request1.arrives_server_time = 0.1
request1.prefill_done_time = 1.0
request1.decode_done_time = 5.0
request1._state = RequestState.DECODE_DONE
request2 = Request(num_input_tokens=200, num_output_tokens=100)
request2.leaves_client_time = 0.5
request2.arrives_server_time = 0.6
request2.prefill_done_time = 2.0
request2.decode_done_time = 10.0
request2._state = RequestState.DECODE_DONE
with tempfile.TemporaryDirectory() as tmpdir:
json_path = os.path.join(tmpdir, "nested", "summary.json")
summarize([request1, request2], output_json_path=json_path)
self.assertTrue(os.path.exists(json_path))
with open(json_path, encoding="utf-8") as f:
payload = json.load(f)
self.assertIn("per_metric_summary", payload)
self.assertIn("overall_summary", payload)
per_metric = payload["per_metric_summary"]
for column in (
"E2E_TIME(s)",
"TTFT(s)",
"TPOT(s)",
"INPUT_TOKENS",
"OUTPUT_TOKENS",
"OUTPUT_TOKEN_THROUGHPUT(tok/s)",
):
self.assertIn(column, per_metric)
for row in ("AVERAGE", "MIN", "MAX", "MEDIAN", "P75", "P90", "P99"):
self.assertIn(row, per_metric[column])
self.assertIsInstance(per_metric[column][row], float)
overall = payload["overall_summary"]
for key in (
"benchmark_duration(s)",
"total_requests",
"request_throughput(req/s)",
"total_input_tokens",
"input_token_throughput(tok/s)",
"total_output_tokens",
"output_token_throughput(tok/s)",
):
self.assertIn(key, overall)
self.assertIsInstance(overall[key], float)
self.assertEqual(overall["total_requests"], 2.0)
if __name__ == "__main__":
unittest.main()