import re
import sys
from unittest import TestCase
from unittest.mock import patch
import pytest
from serving_cast.service.optimizer_summary import SHOW_COLUMNS
from tests.helpers.cli_runner import run_module_main
# Dotted module path of the CLI under test; handed to run_module_main().
THROUGHPUT_OPTIMIZER_MODULE = "cli.inference.throughput_optimizer"

# Every ranked-results table title has the shape "Top <N> <kind> Configurations",
# optionally followed by a colon. Each <kind> alternation below accepts either the
# "PD ..." wording or the "Aggregation" / "Disaggregation (...)" wording.
_TITLE_PREFIX = r"Top\s+\d+\s+"
_TITLE_SUFFIX = r"\s+Configurations\s*:?"

# Title of the aggregated-mode results table.
AGG_TABLE_TITLE_RE = _TITLE_PREFIX + r"(?:PD\s+Aggregated|Aggregation)" + _TITLE_SUFFIX
# Title of the disaggregated prefill results table.
DISAGG_PREFILL_TITLE_RE = (
    _TITLE_PREFIX
    + r"(?:PD\s+Disaggregated\s+Prefill|Disaggregation\s+\(Prefill\))"
    + _TITLE_SUFFIX
)
# Title of the disaggregated decode results table.
DISAGG_DECODE_TITLE_RE = (
    _TITLE_PREFIX
    + r"(?:PD\s+Disaggregated\s+Decode|Disaggregation\s+\(Decode\))"
    + _TITLE_SUFFIX
)
class TestThroughputOptimizer(TestCase):
    """System tests for the throughput_optimizer CLI.

    Most tests drive the module's ``main()`` in-process through ``run_module_main``
    and then check the printed report: section headings, the table header, ASCII
    borders, numeric data rows and the ranked-table title.
    """

    def test_arg_parse_reserved_memory_default_is_ten(self):
        """arg_parse() leaves ``reserved_memory_gb`` at 10.0 when it is not overridden."""
        from cli.inference import throughput_optimizer as throughput_optimizer_module

        argv = [
            "throughput_optimizer",
            "--input-length=1",
            "--output-length=1",
            "Qwen/Qwen3-32B",
        ]
        # arg_parse() reads sys.argv directly, so patch it for the duration of the call.
        with patch.object(sys, "argv", argv):
            args = throughput_optimizer_module.arg_parse()
        self.assertEqual(args.reserved_memory_gb, 10.0)

    def _run_throughput_optimizer(self, args, check=True):
        """Run throughput_optimizer's main() in-process so coverage sees the core path.

        Args:
            args: CLI arguments, without the program name.
            check: When True, raise RuntimeError if the run exits non-zero.

        Returns:
            The ``run_module_main`` result (exposes ``returncode``, ``stdout``,
            ``stderr``).

        Raises:
            RuntimeError: If ``check`` is True and the run exits non-zero.
        """
        result = run_module_main(THROUGHPUT_OPTIMIZER_MODULE, args)
        if check and result.returncode != 0:
            raise RuntimeError(f"throughput_optimizer failed (rc={result.returncode}): {result.stderr}")
        return result

    def _run_and_collect_output(self, args):
        """Run the optimizer, fail the test on a non-zero exit, and return stdout + stderr.

        The report may be written to either stream, so both are concatenated.
        """
        # check=False so a failing run surfaces as a test *failure* with the
        # captured stderr, instead of an uncaught RuntimeError.
        result = self._run_throughput_optimizer(args, check=False)
        if result.returncode != 0:
            self.fail(f"Script execution failed with return code {result.returncode}: {result.stderr}")
        return result.stdout + result.stderr

    def _validate_table_structure(self, output_text, required_columns, table_start_pattern):
        """Validate the overall table structure and format of an optimizer report.

        Args:
            output_text: Combined stdout/stderr of an optimizer run.
            required_columns: Column titles that must all appear on a single line.
            table_start_pattern: Regex that must match the ranked-table title.
        """
        # Section headings are literal text. Use a substring check, not re.search,
        # which would misread any regex metacharacter in a heading.
        for section in ("Input Configuration:", "Overall Best Configuration:"):
            if section not in output_text:
                self.fail(f"Required section '{section}' not found in output")

        # The header is the first line that mentions every required column title.
        header_line = next(
            (line for line in output_text.splitlines() if all(col in line for col in required_columns)),
            None,
        )
        self.assertIsNotNone(header_line, "Table header with required columns not found")

        # ASCII borders look like "+-----+"; require at least two such segments.
        borders = re.findall(r"\+-+\+", output_text)
        self.assertGreaterEqual(len(borders), 2, "Table borders not found or incomplete")

        # Data rows begin with an integer cell, e.g. "| 1 | ... |".
        data_rows = re.findall(r"\|\s*\d+\s*\|.*\|", output_text)
        self.assertGreaterEqual(len(data_rows), 1, "Table data rows not found")

        self.assertIsNotNone(
            re.search(table_start_pattern, output_text),
            "Configurations table title not found",
        )

        # At least one integer-led row must carry a number in its next cell.
        throughput_matches = re.findall(r"\|\s*\d+\s*\|[^\|\n]*\d+(?:\.\d+)?[^\|\n]*\|", output_text)
        self.assertGreaterEqual(len(throughput_matches), 1, "Throughput values not found in table")

    def test_aggregation_functionality_with_output_validation(self):
        """Aggregated mode prints a table containing every SHOW_COLUMNS column."""
        args = [
            "--input-length=3500",
            "--output-length=1500",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--tpot-limits=50",
            "--compile",
        ]
        full_output = self._run_and_collect_output(args)
        self._validate_table_structure(full_output, SHOW_COLUMNS, AGG_TABLE_TITLE_RE)

    def test_disaggregation_prefill_only_with_output_validation(self):
        """Disaggregated run with only a TTFT limit prints the prefill table."""
        args = [
            "--input-length=1024",
            "--output-length=1024",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--ttft-limits=1000",
            "--compile",
            "--disagg",
        ]
        full_output = self._run_and_collect_output(args)
        # The prefill table is validated without the TPOT column.
        local_columns = SHOW_COLUMNS.copy()
        local_columns.remove("TPOT (ms)")
        self._validate_table_structure(full_output, local_columns, DISAGG_PREFILL_TITLE_RE)

    def test_disaggregation_decode_only_with_output_validation(self):
        """Disaggregated run with only a TPOT limit prints the decode table."""
        args = [
            "--input-length=1024",
            "--output-length=1024",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--tpot-limits=50",
            "--compile",
            "--disagg",
            "--tp-sizes",
            "2",
            "4",
            "--batch-range",
            "1",
            "8",
        ]
        full_output = self._run_and_collect_output(args)
        # The decode table is validated without the TTFT column.
        local_columns = SHOW_COLUMNS.copy()
        local_columns.remove("TTFT (ms)")
        self._validate_table_structure(full_output, local_columns, DISAGG_DECODE_TITLE_RE)

    def test_prefix_cache_hit_rate_rejects_invalid_value(self):
        """A hit rate of 1.0 lies outside [0, 1) and must be rejected with a non-zero exit."""
        args = [
            "--input-length=20",
            "--output-length=128",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--prefix-cache-hit-rate=1.0",
        ]
        result = self._run_throughput_optimizer(args, check=False)
        self.assertNotEqual(result.returncode, 0)
        self.assertIn("valid range [0, 1)", result.stderr)

    def test_prefix_cache_hit_rate_aggregation_valid(self):
        """A valid hit rate (0.5) is accepted in aggregated mode."""
        args = [
            "--input-length=64",
            "--output-length=16",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=1",
            "--jobs=1",
            "--tpot-limits=1000",
            "--batch-range",
            "1",
            "2",
            "--prefix-cache-hit-rate=0.5",
        ]
        result = self._run_throughput_optimizer(args, check=False)
        self.assertEqual(result.returncode, 0, msg=result.stderr)

    def test_prefix_cache_hit_rate_disaggregation_prefill_valid(self):
        """A valid hit rate (0.5) is accepted in disaggregated mode with a TTFT limit."""
        args = [
            "--input-length=64",
            "--output-length=16",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=1",
            "--jobs=1",
            "--ttft-limits=1000",
            "--batch-range",
            "1",
            "2",
            "--prefix-cache-hit-rate=0.5",
            "--disagg",
        ]
        result = self._run_throughput_optimizer(args, check=False)
        self.assertEqual(result.returncode, 0, msg=result.stderr)

    def test_prefix_cache_hit_rate_disaggregation_decode_valid(self):
        """A valid hit rate (0.5) is accepted in disaggregated mode with a TPOT limit."""
        args = [
            "--input-length=64",
            "--output-length=16",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=1",
            "--jobs=1",
            "--tpot-limits=1000",
            "--batch-range",
            "1",
            "2",
            "--prefix-cache-hit-rate=0.5",
            "--disagg",
        ]
        result = self._run_throughput_optimizer(args, check=False)
        self.assertEqual(result.returncode, 0, msg=result.stderr)

    def test_prefix_cache_hit_rate_allows_chunked_prefill_when_effective_input_exceeds_max_batched_tokens(
        self,
    ):
        """The run succeeds even when the uncached input exceeds --max-batched-tokens.

        With a 0.5 hit rate, a 200-token input leaves 100 uncached tokens, which is
        above the limit of 99. Per the test name this relies on chunked prefill.
        """
        args = [
            "--input-length=200",
            "--output-length=16",
            "Qwen/Qwen3-32B",
            "--device=TEST_DEVICE",
            "--num-devices=1",
            "--jobs=1",
            "--tpot-limits=1000",
            "--batch-range",
            "1",
            "2",
            "--prefix-cache-hit-rate=0.5",
            "--max-batched-tokens=99",
        ]
        result = self._run_throughput_optimizer(args, check=False)
        self.assertEqual(result.returncode, 0, msg=result.stderr)

    def test_deepseek_model_pd_ratio_with_output_validation(self):
        """DeepSeek PD-ratio optimisation prints a "PD Ratio Configurations" table."""
        args = [
            "--input-length=3500",
            "--output-length=1500",
            "deepseek-ai/DeepSeek-V3.1",
            "--enable-optimize-prefill-decode-ratio",
            "--prefill-devices-per-instance=32",
            "--decode-devices-per-instance=32",
            "--compile",
            "--quantize-linear-action=W8A8_DYNAMIC",
            "--quantize-attention-action=INT8",
            "--device=TEST_DEVICE",
            "--jobs=10",
            "--ttft-limits=7000",
            "--tpot-limits=200",
        ]
        # Previously this ran with check=True, which raised before the explicit
        # return-code failure message could ever be reached.
        full_output = self._run_and_collect_output(args)
        local_columns = [
            "Top",
            "PD Ratio",
            "P QPS (req/s)",
            "D QPS (req/s)",
            "TTFT (ms)",
            "TPOT (ms)",
            "P Parallel",
            "D Parallel",
            "P Devices/Instance",
            "D Devices/Instance",
            "P Batch Size",
            "D Batch Size",
            "P Concurrency",
            "D Concurrency",
        ]
        table_start_pattern = r"\s*Top\s+\d+\s+PD Ratio Configurations:"
        self._validate_table_structure(full_output, local_columns, table_start_pattern)
@pytest.mark.nightly
class TestThroughputOptimizerNightly(TestCase):
    """VL and VL-MoE optimizer scenarios, tagged with the ``nightly`` pytest marker.

    The helpers are borrowed from TestThroughputOptimizer by calling its functions
    with this instance as ``self``, rather than by subclassing. Subclassing would
    make unittest/pytest re-collect every fast-suite test under this class.
    """

    def _run_throughput_optimizer(self, args, check=True):
        """Delegate to TestThroughputOptimizer._run_throughput_optimizer."""
        return TestThroughputOptimizer._run_throughput_optimizer(self, args, check)

    def _validate_table_structure(self, output_text, required_columns, table_start_pattern):
        """Delegate to TestThroughputOptimizer._validate_table_structure."""
        return TestThroughputOptimizer._validate_table_structure(
            self, output_text, required_columns, table_start_pattern
        )

    def _run_and_collect_output(self, args):
        """Run the optimizer, fail the test on a non-zero exit, and return stdout + stderr."""
        # check=False: with check=True a failing run raised RuntimeError before the
        # explicit self.fail() below could ever be reached.
        result = self._run_throughput_optimizer(args, check=False)
        if result.returncode != 0:
            self.fail(f"Script execution failed with return code {result.returncode}: {result.stderr}")
        return result.stdout + result.stderr

    def test_vl_model_aggregation_with_output_validation(self):
        """VL model in aggregated mode prints a table containing every SHOW_COLUMNS column."""
        args = [
            "--input-length=1024",
            "--output-length=1024",
            "Qwen/Qwen3-VL-30B-A3B-Instruct",
            "--device=TEST_DEVICE",
            "--num-devices=4",
            "--tpot-limits=100",
            "--image-height=512",
            "--image-width=512",
        ]
        full_output = self._run_and_collect_output(args)
        self._validate_table_structure(full_output, SHOW_COLUMNS.copy(), AGG_TABLE_TITLE_RE)

    def test_vl_model_disaggregation_prefill_with_output_validation(self):
        """VL model, disaggregated with a TTFT limit, prints the prefill table."""
        args = [
            "--input-length=1024",
            "--output-length=1024",
            "Qwen/Qwen3-VL-30B-A3B-Instruct",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--ttft-limits=2000",
            "--image-height=512",
            "--image-width=512",
            "--disagg",
            "--batch-range",
            "1",
            "8",
        ]
        full_output = self._run_and_collect_output(args)
        # The prefill table is validated without the TPOT column.
        local_columns = SHOW_COLUMNS.copy()
        local_columns.remove("TPOT (ms)")
        self._validate_table_structure(full_output, local_columns, DISAGG_PREFILL_TITLE_RE)

    def test_vl_model_disaggregation_decode_with_output_validation(self):
        """VL model, disaggregated with a TPOT limit, prints the decode table."""
        args = [
            "--input-length=1024",
            "--output-length=1024",
            "zai-org/GLM-4.5V",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--tpot-limits=100",
            "--image-height=512",
            "--image-width=512",
            "--disagg",
        ]
        full_output = self._run_and_collect_output(args)
        # The decode table is validated without the TTFT column.
        local_columns = SHOW_COLUMNS.copy()
        local_columns.remove("TTFT (ms)")
        self._validate_table_structure(full_output, local_columns, DISAGG_DECODE_TITLE_RE)

    def test_VL_MOE_model_aggregation_with_output_validation(self):
        """Quantized VL-MoE model in aggregated mode, with a small --max-batched-tokens."""
        args = [
            "--input-length=20",
            "--output-length=128",
            "Qwen/Qwen3-VL-235B-A22B-Instruct",
            "--device=TEST_DEVICE",
            "--num-devices=8",
            "--image-height=1080",
            "--image-width=1920",
            "--compile",
            "--quantize-linear-action=W8A8_DYNAMIC",
            "--quantize-attention-action=INT8",
            "--batch-range",
            "1",
            "4",
            "--max-batched-tokens=100",
        ]
        full_output = self._run_and_collect_output(args)
        self._validate_table_structure(full_output, SHOW_COLUMNS.copy(), AGG_TABLE_TITLE_RE)