"""Unit tests for optix.run_throughput_optimizer_cases."""
import csv
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
def _ensure_on_sys_path(directory):
    """Put *directory* at the front of ``sys.path`` unless it is already listed.

    Returns the directory as a string so callers can keep it as a module-level name.
    """
    entry = str(directory)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return entry


# Both ancestors of this file must be importable before the project imports
# that follow: first the directory one level above this file's folder, then
# the one two levels above.
experimental_dir = _ensure_on_sys_path(Path(__file__).resolve().parents[1])
project_root = _ensure_on_sys_path(Path(__file__).resolve().parents[2])
from optix.run_throughput_optimizer_cases import (
DEFAULT_TPOT_LIMIT_MS,
FLUSH_BATCH_SIZE,
LOG_LEVELS,
BenchmarkCase,
BenchmarkResult,
CSV_CONFIG_HEADER,
_build_optimizer_args,
_configure_logging,
_csv_header_and_ref_row,
_filter_best_row,
_parse_args,
_parse_bool,
_parse_list_float,
_parse_list_int,
_parse_mode,
_parse_optional_bool,
_parse_parallel,
_result_row,
_safe_float,
_single_limit,
load_cases_from_csv,
save_results_to_csv,
write_template_csv,
)
from tensor_cast.core.quantization.datatypes import (
QuantizeLinearAction,
QuantizeAttentionAction,
)
class TestParseListFloat(unittest.TestCase):
    """Checks how ``_parse_list_float`` converts semicolon-separated text to floats."""

    def test_normal_semicolon_separated(self):
        """Three ``;``-separated numbers come back as three floats, in order."""
        parsed = _parse_list_float("1.0;2.0;3.0")
        self.assertEqual(parsed, [1.0, 2.0, 3.0])

    def test_single_value(self):
        """A lone number yields a one-element list."""
        parsed = _parse_list_float("50.0")
        self.assertEqual(parsed, [50.0])

    def test_empty_string(self):
        """An empty string yields an empty list."""
        parsed = _parse_list_float("")
        self.assertEqual(parsed, [])

    def test_none(self):
        """``None`` is treated like missing input and yields an empty list."""
        parsed = _parse_list_float(None)
        self.assertEqual(parsed, [])

    def test_whitespace_only(self):
        """Blank input yields an empty list."""
        parsed = _parse_list_float(" ")
        self.assertEqual(parsed, [])

    def test_values_with_spaces(self):
        """Padding around the values and the separator is ignored."""
        parsed = _parse_list_float(" 1.0 ; 2.0 ")
        self.assertEqual(parsed, [1.0, 2.0])
class TestParseListInt(unittest.TestCase):
    """Checks ``_parse_list_int``: integers for real input, ``None`` for missing input."""

    def test_normal(self):
        """Semicolon-separated integers are returned in order."""
        values = _parse_list_int("1;2;4")
        self.assertEqual(values, [1, 2, 4])

    def test_single_value(self):
        """A lone integer yields a one-element list."""
        values = _parse_list_int("8")
        self.assertEqual(values, [8])

    def test_empty_returns_none(self):
        """An empty string gives ``None`` rather than an empty list."""
        values = _parse_list_int("")
        self.assertIsNone(values)

    def test_none(self):
        """``None`` input gives ``None`` back."""
        values = _parse_list_int(None)
        self.assertIsNone(values)
class TestParseBool(unittest.TestCase):
    """Checks the truthy/falsy spellings accepted by ``_parse_bool``.

    The multi-value tests wrap each spelling in ``subTest`` so that one bad
    spelling does not abort the loop and hide failures for the remaining ones.
    """

    def test_true_variants(self):
        """Each listed spelling must parse as true."""
        for v in ("true", "True", "1", "yes", "YES"):
            with self.subTest(value=v):
                self.assertTrue(_parse_bool(v), f"Expected True for '{v}'")

    def test_false_variants(self):
        """Explicit negatives and unrecognised text must parse as false."""
        for v in ("false", "0", "no", "random"):
            with self.subTest(value=v):
                self.assertFalse(_parse_bool(v), f"Expected False for '{v}'")

    def test_none(self):
        """``None`` parses as false."""
        self.assertFalse(_parse_bool(None))

    def test_empty(self):
        """The empty string parses as false."""
        self.assertFalse(_parse_bool(""))
class TestParseOptionalBool(unittest.TestCase):
    """Checks the tri-state result of ``_parse_optional_bool`` (true / false / ``None``)."""

    def test_true(self):
        """``"true"`` gives a truthy result."""
        outcome = _parse_optional_bool("true")
        self.assertTrue(outcome)

    def test_false(self):
        """``"false"`` gives a falsy result."""
        outcome = _parse_optional_bool("false")
        self.assertFalse(outcome)

    def test_empty_returns_none(self):
        """An empty string means "not specified" and gives ``None``."""
        outcome = _parse_optional_bool("")
        self.assertIsNone(outcome)

    def test_none_returns_none(self):
        """``None`` input gives ``None`` back."""
        outcome = _parse_optional_bool(None)
        self.assertIsNone(outcome)

    def test_invalid_returns_none(self):
        """Unrecognised text gives ``None`` instead of a boolean."""
        outcome = _parse_optional_bool("maybe")
        self.assertIsNone(outcome)
class TestParseMode(unittest.TestCase):
    """Checks ``_parse_mode``: known modes pass through, everything else becomes ``"agg"``."""

    def _assert_mode(self, raw, expected):
        """Parse *raw* and compare the resulting mode string with *expected*."""
        self.assertEqual(_parse_mode(raw), expected)

    def test_agg(self):
        """``"agg"`` is returned unchanged."""
        self._assert_mode("agg", "agg")

    def test_disagg(self):
        """``"disagg"`` is returned unchanged."""
        self._assert_mode("disagg", "disagg")

    def test_default_empty(self):
        """An empty string defaults to ``"agg"``."""
        self._assert_mode("", "agg")

    def test_default_none(self):
        """``None`` defaults to ``"agg"``."""
        self._assert_mode(None, "agg")

    def test_invalid_falls_back_to_agg(self):
        """An unknown mode name falls back to ``"agg"``."""
        self._assert_mode("invalid", "agg")
class TestParseParallel(unittest.TestCase):
    """Checks that ``_parse_parallel`` extracts a ``(tp, pp, dp)`` triple from text."""

    def _assert_triple(self, text, expected):
        """Parse *text* and compare the whole result with the *expected* tuple."""
        self.assertEqual(_parse_parallel(text), expected)

    def test_valid(self):
        """Compact form with single-digit sizes."""
        self._assert_triple("tp1pp1dp1", (1, 1, 1))

    def test_multi_digit(self):
        """Compact form with a different size for each dimension."""
        self._assert_triple("tp2pp3dp4", (2, 3, 4))

    def test_empty(self):
        """An empty string gives three ``None`` values."""
        self._assert_triple("", (None, None, None))

    def test_invalid_format(self):
        """Text without any recognisable field gives three ``None`` values."""
        self._assert_triple("abc", (None, None, None))

    def test_none(self):
        """``None`` input gives three ``None`` values."""
        self._assert_triple(None, (None, None, None))

    def test_verbose_format(self):
        """Upper-case ``KEY=value | ...`` form is understood."""
        self._assert_triple("TP=4 | PP=1 | DP=1", (4, 1, 1))

    def test_verbose_format_lower(self):
        """Lower-case ``key=value | ...`` form is understood."""
        self._assert_triple("tp=2 | pp=3 | dp=4", (2, 3, 4))

    def test_verbose_partial(self):
        """A field missing from the verbose form is reported as ``None``."""
        self._assert_triple("TP=4 | DP=1", (4, None, 1))
class TestSingleLimit(unittest.TestCase):
    """Checks ``_single_limit``: zero or one value is accepted, more is an error."""

    def _rejection_message(self, values, field_name):
        """Call ``_single_limit`` expecting ``ValueError`` and return its message."""
        with self.assertRaises(ValueError) as caught:
            _single_limit(values, field_name)
        return str(caught.exception)

    def test_empty_returns_none(self):
        """An empty list means no limit and gives ``None``."""
        limit = _single_limit([], "test_field")
        self.assertIsNone(limit)

    def test_single_value_returns_value(self):
        """A one-element list is unwrapped to its element."""
        limit = _single_limit([50.0], "test_field")
        self.assertEqual(limit, 50.0)

    def test_multiple_raises_value_error(self):
        """Two values are rejected; the message mentions the field name and "2"."""
        text = self._rejection_message([1.0, 2.0], "ttft_limits")
        self.assertIn("ttft_limits", text)
        self.assertIn("2", text)

    def test_error_message_contains_name_and_count(self):
        """Three values are rejected; the message mentions the field name and "3"."""
        msg = self._rejection_message([1.0, 2.0, 3.0], "my_limits")
        self.assertIn("my_limits", msg)
        self.assertIn("3", msg)
class TestLoadCasesFromCsv(unittest.TestCase):
    """Exercises ``load_cases_from_csv`` against small CSV files written to a temp dir.

    Every data row has 24 cells. Tests start from ``_BASE_ROW`` and override
    individual cells by position, so each test shows only what it changes.
    """

    # Baseline data row. Positions overridden by the tests below:
    #   0 = case name, 2 = device count, 4/5 = input/output length,
    #   6 = TTFT limit, 7 = TPOT limit, 9 = linear quantization action,
    #   10 = attention quantization action, 14 = compile flag.
    _BASE_ROW = (
        "test_case",
        "TEST_DEVICE",
        "1",
        "Qwen/Qwen3-32B",
        "100",
        "50",
        "",
        "50",
        "",
        "",
        "",
        "",
        "0",
        "",
        "false",
        "agg",
        "8192",
        "",
        "0",
        "8",
        "info",
        "32",
        "0",
        "false",
    )

    def setUp(self):
        """Create a scratch directory for the CSV files of one test."""
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory; removal problems are ignored."""
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _write_csv(self, rows, header=None):
        """Write *header* (default ``CSV_CONFIG_HEADER``) and *rows*; return the file path."""
        target = os.path.join(self.tmpdir, "cases.csv")
        columns = CSV_CONFIG_HEADER if header is None else header
        with open(target, "w", newline="", encoding="utf-8") as handle:
            out = csv.writer(handle)
            out.writerow(columns)
            out.writerows(rows)
        return target

    def _row_with(self, changes=None):
        """Return a copy of ``_BASE_ROW`` with ``{position: text}`` overrides applied."""
        cells = list(self._BASE_ROW)
        for position, text in (changes or {}).items():
            cells[position] = text
        return cells

    def _load_failure_message(self, changes):
        """Load a one-row CSV that must raise ``ValueError``; return the error text."""
        path = self._write_csv([self._row_with(changes)])
        with self.assertRaises(ValueError) as ctx:
            load_cases_from_csv(path)
        return str(ctx.exception)

    def test_basic_load(self):
        """A fully populated row is parsed into the matching case fields."""
        row = self._row_with(
            {
                2: "8",
                4: "3500",
                5: "1500",
                6: "2000",
                9: "W8A8_DYNAMIC",
                10: "DISABLED",
                14: "true",
            }
        )
        cases = load_cases_from_csv(self._write_csv([row]))
        self.assertEqual(len(cases), 1)
        c = cases[0]
        self.assertEqual(c.case_name, "test_case")
        self.assertEqual(c.device, "TEST_DEVICE")
        self.assertEqual(c.num_devices, 8)
        self.assertEqual(c.model_id, "Qwen/Qwen3-32B")
        self.assertEqual(c.input_length, 3500)
        self.assertEqual(c.output_length, 1500)
        self.assertEqual(c.ttft_limits, [2000.0])
        self.assertEqual(c.tpot_limits, [50.0])
        self.assertEqual(c.quantize_linear_action, QuantizeLinearAction.W8A8_DYNAMIC)
        self.assertEqual(c.quantize_attention_action, QuantizeAttentionAction.DISABLED)
        self.assertTrue(c.do_compile)
        self.assertEqual(c.mode, "agg")

    def test_empty_tpot_uses_default_ms(self):
        """A blank TPOT cell is replaced by ``DEFAULT_TPOT_LIMIT_MS``."""
        path = self._write_csv([self._row_with({7: ""})])
        cases = load_cases_from_csv(path)
        self.assertEqual(cases[0].tpot_limits, [DEFAULT_TPOT_LIMIT_MS])

    def test_invalid_quantize_linear_raises(self):
        """An unknown linear quantization name is rejected with a descriptive error."""
        msg = self._load_failure_message({9: "INVALID_QUANT"})
        self.assertIn("quantize_linear_action", msg)
        self.assertIn("INVALID_QUANT", msg)
        self.assertIn("Valid options:", msg)

    def test_invalid_quantize_attention_raises(self):
        """An unknown attention quantization name is rejected with a descriptive error."""
        msg = self._load_failure_message({10: "BAD_ATTN"})
        self.assertIn("quantize_attention_action", msg)
        self.assertIn("BAD_ATTN", msg)
        self.assertIn("Valid options:", msg)

    def test_error_message_lists_valid_quantize_options(self):
        """The rejection message names every ``QuantizeLinearAction`` value."""
        msg = self._load_failure_message({9: "NOPE"})
        for action in QuantizeLinearAction:
            self.assertIn(action.value, msg)

    def test_no_header_raises(self):
        """A completely empty file is rejected because it has no header."""
        path = os.path.join(self.tmpdir, "empty.csv")
        with open(path, "w", newline="", encoding="utf-8") as f:
            f.write("")
        with self.assertRaises(ValueError) as ctx:
            load_cases_from_csv(path)
        self.assertIn("no header", str(ctx.exception))

    def test_empty_rows_skipped(self):
        """A row made only of blank cells produces no case."""
        blank_row = [""] * len(CSV_CONFIG_HEADER)
        cases = load_cases_from_csv(self._write_csv([blank_row]))
        self.assertEqual(len(cases), 0)

    def test_missing_case_name_gets_row_n(self):
        """A blank case name is replaced by ``row_1`` for the first data row."""
        path = self._write_csv([self._row_with({0: ""})])
        cases = load_cases_from_csv(path)
        self.assertEqual(cases[0].case_name, "row_1")
class TestWriteTemplateCsv(unittest.TestCase):
    """Checks the header and example rows of the file produced by ``write_template_csv``."""

    def setUp(self):
        """Create a scratch directory for the generated template."""
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory; removal problems are ignored."""
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _written_template(self):
        """Generate the template and return ``(header, data_rows)`` read back from disk."""
        path = os.path.join(self.tmpdir, "template.csv")
        write_template_csv(path)
        with open(path, "r", encoding="utf-8") as handle:
            parsed = csv.reader(handle)
            header = next(parsed)
            data_rows = list(parsed)
        return header, data_rows

    def test_template_has_correct_header(self):
        """The first line equals ``CSV_CONFIG_HEADER``."""
        header, _ = self._written_template()
        self.assertEqual(header, CSV_CONFIG_HEADER)

    def test_template_has_example_rows(self):
        """Three named example rows follow the header, all for the same model."""
        _, rows = self._written_template()
        self.assertEqual(len(rows), 3)
        self.assertEqual(rows[0][0], "1card_agg_w8a8")
        self.assertEqual(rows[1][0], "8card_agg_w8a8")
        self.assertEqual(rows[2][0], "4card_disagg_mtp")
        for example in rows:
            self.assertEqual(example[3], "Qwen/Qwen3-32B")

    def test_template_example_tpot_is_50ms(self):
        """Column 7 of every example row holds the default TPOT limit as an integer string."""
        expected_cell = str(int(DEFAULT_TPOT_LIMIT_MS))
        _, rows = self._written_template()
        for example in rows:
            self.assertEqual(example[7], expected_cell)
class TestBuildOptimizerArgs(unittest.TestCase):
    """Checks how ``_build_optimizer_args`` maps a ``BenchmarkCase`` to optimizer arguments."""

    def _make_case(self, **overrides):
        """Build a small ``BenchmarkCase``; keyword arguments replace the baseline fields."""
        baseline = {
            "case_name": "test",
            "device": "TEST_DEVICE",
            "num_devices": 1,
            "model_id": "Qwen/Qwen3-32B",
            "input_length": 100,
            "output_length": 50,
            "ttft_limits": [2000.0],
            "tpot_limits": [50.0],
        }
        return BenchmarkCase(**{**baseline, **overrides})

    def _args_for(self, **overrides):
        """Return the optimizer arguments built from a case with *overrides* applied."""
        return _build_optimizer_args(self._make_case(**overrides))

    def _rejection_message(self, **overrides):
        """Build arguments expecting ``ValueError`` and return the error text."""
        case = self._make_case(**overrides)
        with self.assertRaises(ValueError) as caught:
            _build_optimizer_args(case)
        return str(caught.exception)

    def test_agg_mode_no_disagg(self):
        """Mode ``"agg"`` leaves the ``disagg`` flag off."""
        built = self._args_for(mode="agg")
        self.assertFalse(built.disagg)

    def test_disagg_mode_sets_disagg_true(self):
        """Mode ``"disagg"`` switches the ``disagg`` flag on."""
        built = self._args_for(mode="disagg")
        self.assertTrue(built.disagg)

    def test_single_limit_applied_to_ttft(self):
        """A one-element TTFT list is passed on as a plain number."""
        built = self._args_for(ttft_limits=[2000.0])
        self.assertEqual(built.ttft_limits, 2000.0)

    def test_single_limit_applied_to_tpot(self):
        """A one-element TPOT list is passed on as a plain number."""
        built = self._args_for(tpot_limits=[50.0])
        self.assertEqual(built.tpot_limits, 50.0)

    def test_empty_ttft_gives_none(self):
        """An empty TTFT list becomes ``None``."""
        built = self._args_for(ttft_limits=[])
        self.assertIsNone(built.ttft_limits)

    def test_multiple_ttft_raises(self):
        """More than one TTFT limit is rejected and the field is named in the error."""
        text = self._rejection_message(ttft_limits=[1.0, 2.0])
        self.assertIn("ttft_limits", text)

    def test_multiple_tpot_raises(self):
        """More than one TPOT limit is rejected and the field is named in the error."""
        text = self._rejection_message(tpot_limits=[50.0, 100.0])
        self.assertIn("tpot_limits", text)
class TestBenchmarkResult(unittest.TestCase):
    """Checks ``BenchmarkResult`` and its rendering to CSV rows and files."""

    @staticmethod
    def _minimal_result():
        """Return a result carrying only the six identifying fields."""
        return BenchmarkResult(
            case_name="test",
            device="DEV",
            num_devices=1,
            model_id="model",
            input_length=100,
            output_length=50,
        )

    def test_csv_header_matches_result_row_length(self):
        """A rendered row has exactly as many cells as the CSV header."""
        columns, _ = _csv_header_and_ref_row()
        rendered = _result_row(self._minimal_result())
        self.assertEqual(len(columns), len(rendered))

    def test_csv_roundtrip(self):
        """A saved result is found on the third line, after the header and reference rows."""
        tmpdir = tempfile.mkdtemp()
        # Registered straight away so the directory goes away whether or not
        # the test body passes; removal problems are ignored.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

        decode_fields = {
            "best_decode_tpot_ms": 40.5,
            "best_decode_total_tps": 100.0,
            "best_decode_tps_per_device": 50.0,
            "best_decode_tp_size": 2,
            "best_decode_pp_size": 1,
            "best_decode_dp_size": 1,
            "best_decode_concurrency": 10,
        }
        result = BenchmarkResult(
            case_name="test",
            device="DEV",
            num_devices=2,
            model_id="Qwen/Qwen3-32B",
            input_length=100,
            output_length=50,
            **decode_fields,
        )
        path = os.path.join(tmpdir, "results.csv")
        save_results_to_csv([result], path)

        with open(path, "r", encoding="utf-8") as handle:
            lines = csv.reader(handle)
            next(lines)  # header line
            next(lines)  # reference line
            row = next(lines)

        self.assertEqual(row[0], "test")
        self.assertEqual(row[1], "DEV")
        self.assertIn("40.50", row[12])

    def test_no_error_fields_in_result(self):
        """The result object exposes no ``best_*_error`` attributes."""
        result = self._minimal_result()
        for attribute in ("best_decode_error", "best_prefill_error"):
            self.assertFalse(hasattr(result, attribute))
class TestParseArgs(unittest.TestCase):
    """Checks the 5-tuple returned by ``_parse_args`` for various command lines.

    The tuple order used below is
    ``(input_csv, write_template, output_csv, test_conversion, validate_csv)``.
    """

    def _parse(self, *cli):
        """Run ``_parse_args`` with ``sys.argv`` temporarily set to ``prog`` plus *cli*."""
        with patch.object(sys, "argv", ["prog", *cli]):
            return _parse_args()

    def test_input_csv(self):
        """``--input-csv`` fills the first slot."""
        parsed = self._parse("--input-csv", "cases.csv")
        input_csv, _, _, _, _ = parsed
        self.assertEqual(input_csv, "cases.csv")

    def test_write_template(self):
        """``--write-template`` fills the second slot."""
        parsed = self._parse("--write-template", "tmpl.csv")
        _, write_template, _, _, _ = parsed
        self.assertEqual(write_template, "tmpl.csv")

    def test_output_csv(self):
        """``--output-csv`` fills the third slot."""
        parsed = self._parse("--output-csv", "out.csv")
        _, _, output_csv, _, _ = parsed
        self.assertEqual(output_csv, "out.csv")

    def test_test_conversion_flag(self):
        """``--test-conversion`` makes the fourth slot truthy."""
        parsed = self._parse("--test-conversion")
        _, _, _, test_conv, _ = parsed
        self.assertTrue(test_conv)

    def test_defaults(self):
        """Without options, the paths are ``None`` and the conversion flag is falsy."""
        input_csv, write_template, output_csv, test_conv, validate_csv = self._parse()
        self.assertIsNone(input_csv)
        self.assertIsNone(write_template)
        self.assertIsNone(output_csv)
        self.assertFalse(test_conv)
        self.assertIsNone(validate_csv)

    def test_help_exits(self):
        """``--help`` terminates via ``SystemExit`` with exit code 0."""
        with patch.object(sys, "argv", ["prog", "--help"]), self.assertRaises(SystemExit) as ctx:
            _parse_args()
        self.assertEqual(ctx.exception.code, 0)
class TestSaveResultsToCsv(unittest.TestCase):
    """Checks ``save_results_to_csv`` output and the ``FLUSH_BATCH_SIZE`` constant."""

    def test_results_written_correctly(self):
        """Both case names appear somewhere in the written file."""
        tmpdir = tempfile.mkdtemp()
        # Registered straight away so the directory goes away whether or not
        # the test body passes; removal problems are ignored.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

        first = BenchmarkResult(
            case_name="case1",
            device="DEV1",
            num_devices=1,
            model_id="model1",
            input_length=100,
            output_length=50,
        )
        second = BenchmarkResult(
            case_name="case2",
            device="DEV2",
            num_devices=2,
            model_id="model2",
            input_length=200,
            output_length=100,
        )
        path = os.path.join(tmpdir, "results.csv")
        save_results_to_csv([first, second], path)

        with open(path, "r", encoding="utf-8") as handle:
            content = handle.read()
        for expected_name in ("case1", "case2"):
            self.assertIn(expected_name, content)

    def test_batch_flush_constant(self):
        """The flush batch size is pinned at 10."""
        expected_batch = 10
        self.assertEqual(FLUSH_BATCH_SIZE, expected_batch)
class TestDefaultTpotLimitMs(unittest.TestCase):
    """Pins the value of the ``DEFAULT_TPOT_LIMIT_MS`` constant."""

    def test_default_is_50_ms(self):
        """The default TPOT limit equals 50.0."""
        expected_limit = 50.0
        self.assertEqual(DEFAULT_TPOT_LIMIT_MS, expected_limit)
class TestSafeFloat(unittest.TestCase):
    """Checks ``_safe_float``: finite numbers convert, everything else gives ``None``."""

    def test_normal_float(self):
        """A finite float is returned unchanged."""
        converted = _safe_float(3.14)
        self.assertEqual(converted, 3.14)

    def test_none_returns_none(self):
        """``None`` stays ``None``."""
        converted = _safe_float(None)
        self.assertIsNone(converted)

    def test_nan_returns_none(self):
        """NaN is rejected."""
        not_a_number = float("nan")
        self.assertIsNone(_safe_float(not_a_number))

    def test_inf_returns_none(self):
        """Positive infinity is rejected."""
        positive_infinity = float("inf")
        self.assertIsNone(_safe_float(positive_infinity))

    def test_negative_inf_returns_none(self):
        """Negative infinity is rejected."""
        negative_infinity = float("-inf")
        self.assertIsNone(_safe_float(negative_infinity))

    def test_string_float(self):
        """Numeric text is converted to a float."""
        converted = _safe_float("3.14")
        self.assertEqual(converted, 3.14)

    def test_invalid_string_returns_none(self):
        """Non-numeric text is rejected."""
        converted = _safe_float("abc")
        self.assertIsNone(converted)

    def test_int_returns_float(self):
        """An integer compares equal to the matching float after conversion."""
        converted = _safe_float(42)
        self.assertEqual(converted, 42.0)
class TestRequiredColumns(unittest.TestCase):
    """Checks that ``load_cases_from_csv`` validates the header for required columns."""

    def setUp(self):
        """Create a scratch directory for the CSV files of one test."""
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory; removal problems are ignored."""
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _write_single_row_csv(self, filename, header, row):
        """Write *header* and one data *row* to *filename* in the scratch dir; return the path."""
        target = os.path.join(self.tmpdir, filename)
        with open(target, "w", newline="", encoding="utf-8") as handle:
            out = csv.writer(handle)
            out.writerow(header)
            out.writerow(row)
        return target

    def test_missing_required_columns_raises(self):
        """A header without the ``device`` column is rejected and the column is named."""
        path = self._write_single_row_csv(
            "bad.csv",
            ["case_name", "model_id", "input_length", "output_length"],
            ["test", "Qwen/Qwen3-32B", "100", "50"],
        )
        with self.assertRaises(ValueError) as ctx:
            load_cases_from_csv(path)
        msg = str(ctx.exception)
        self.assertIn("missing required columns", msg)
        self.assertIn("device", msg)

    def test_all_required_columns_present(self):
        """With the full ``CSV_CONFIG_HEADER``, one data row loads as one case."""
        complete_row = [
            "test",
            "TEST_DEVICE",
            "1",
            "Qwen/Qwen3-32B",
            "100",
            "50",
            "",
            "50",
            "",
            "",
            "",
            "",
            "0",
            "",
            "false",
            "agg",
            "8192",
            "",
            "0",
            "8",
            "info",
            "32",
            "0",
            "false",
        ]
        path = self._write_single_row_csv("good.csv", CSV_CONFIG_HEADER, complete_row)
        cases = load_cases_from_csv(path)
        self.assertEqual(len(cases), 1)
class TestValidateCsvArg(unittest.TestCase):
    """Checks that ``--validate-csv`` is surfaced by ``_parse_args``."""

    def test_validate_csv_flag(self):
        """The path given to ``--validate-csv`` is the last element of the parsed tuple."""
        command_line = ["prog", "--validate-csv", "cases.csv"]
        with patch.object(sys, "argv", command_line):
            parsed = _parse_args()
        _, _, _, _, validate_csv = parsed
        self.assertEqual(validate_csv, "cases.csv")
class TestFilterBestRow(unittest.TestCase):
    """Tests for _filter_best_row using only public OptimizerSummary API.

    The summary is replaced by a stub exposing ``data_config.tpot_limits``,
    ``data_config.ttft_limits`` and ``get_summary_df()``.
    """

    def _make_summary(self, df, tpot_limits=None, ttft_limits=None):
        """Return a stub summary that serves *df* together with the given limits."""

        class _StubDataConfig:
            def __init__(self, tpot, ttft):
                self.tpot_limits = tpot
                self.ttft_limits = ttft

        class _StubSummary:
            def __init__(self, frame, config):
                self._frame = frame
                self.data_config = config

            def get_summary_df(self):
                return self._frame

        return _StubSummary(df, _StubDataConfig(tpot_limits, ttft_limits))

    def _best_for(self, records, tpot_limits=None, ttft_limits=None):
        """Build a frame from *records* and return what ``_filter_best_row`` selects."""
        # pandas is imported lazily, only when a test actually runs.
        import pandas as pd

        frame = pd.DataFrame(records)
        stub = self._make_summary(frame, tpot_limits=tpot_limits, ttft_limits=ttft_limits)
        return _filter_best_row(stub)

    def test_returns_best_row_when_within_slo(self):
        """With both rows under the TPOT limit, the higher-throughput row is chosen."""
        records = [
            {"parallel": "tp1pp1dp1", "tpot": 40.0, "ttft": None, "token/s": 100.0},
            {"parallel": "tp2pp1dp1", "tpot": 30.0, "ttft": None, "token/s": 200.0},
        ]
        best = self._best_for(records, tpot_limits=50.0, ttft_limits=None)
        self.assertIsNotNone(best)
        self.assertEqual(best["parallel"], "tp2pp1dp1")
        self.assertEqual(best["token/s"], 200.0)

    def test_returns_none_when_no_row_meets_slo(self):
        """A single row above the TPOT limit leaves nothing to select."""
        records = [
            {
                "parallel": "tp1pp1dp1",
                "tpot": 100.0,
                "ttft": None,
                "token/s": 100.0,
            },
        ]
        best = self._best_for(records, tpot_limits=50.0, ttft_limits=None)
        self.assertIsNone(best)

    def test_returns_none_for_empty_df(self):
        """A frame with columns but no rows gives ``None``."""
        import pandas as pd

        empty = pd.DataFrame(columns=["parallel", "tpot", "ttft", "token/s"])
        stub = self._make_summary(empty, tpot_limits=50.0, ttft_limits=None)
        self.assertIsNone(_filter_best_row(stub))

    def test_prefill_phase_isolates_ttft_filter(self):
        """With only a TTFT limit set, a row with no TPOT value is still selected."""
        records = [
            {
                "parallel": "tp1pp1dp1",
                "tpot": None,
                "ttft": 1500.0,
                "token/s": 100.0,
            },
        ]
        best = self._best_for(records, tpot_limits=None, ttft_limits=2000.0)
        self.assertIsNotNone(best)
        self.assertEqual(best["ttft"], 1500.0)

    def test_decode_phase_isolates_tpot_filter(self):
        """With only a TPOT limit set, a row with no TTFT value is still selected."""
        records = [
            {"parallel": "tp1pp1dp1", "tpot": 40.0, "ttft": None, "token/s": 100.0},
        ]
        best = self._best_for(records, tpot_limits=50.0, ttft_limits=None)
        self.assertIsNotNone(best)
        self.assertEqual(best["tpot"], 40.0)
class TestConfigureLogging(unittest.TestCase):
    """Tests for _configure_logging helper.

    The tests below observe ``_configure_logging`` changing the root logger's
    level. ``setUp`` therefore records the level in effect before each test
    and registers a cleanup that restores it, so the level does not leak into
    tests that run afterwards.
    """

    def setUp(self):
        """Snapshot the root logger level and schedule its restoration."""
        import logging

        root = logging.getLogger()
        # NOTE(review): only the level is restored; if _configure_logging also
        # installs handlers, those would need the same treatment - confirm.
        self.addCleanup(root.setLevel, root.level)

    def test_known_level(self):
        """``"debug"`` sets the root logger to ``logging.DEBUG``."""
        import logging

        _configure_logging("debug")
        self.assertEqual(logging.getLogger().level, logging.DEBUG)

    def test_unknown_level_falls_back_to_info(self):
        """An unrecognised level name sets the root logger to ``logging.INFO``."""
        import logging

        _configure_logging("unknown_level")
        self.assertEqual(logging.getLogger().level, logging.INFO)

    def test_log_levels_constant_present(self):
        """``LOG_LEVELS`` contains every supported level name."""
        for k in ("debug", "info", "warning", "error", "fatal", "critical"):
            # subTest keeps checking the remaining names after one is missing.
            with self.subTest(level=k):
                self.assertIn(k, LOG_LEVELS)
class TestIntegrationExampleCase(unittest.TestCase):
    """End-to-end checks built around one DeepSeek-V3 disaggregated example case.

    ``INPUT_CSV_ROW`` is the configuration row fed to ``load_cases_from_csv``:
        deepseek-ai/DeepSeek-V3,ATLAS_800_A3_752T_128G_DIE,64,
        deepseek-ai/DeepSeek-V3,3500,1000,,20,1,W8A8_DYNAMIC,DISABLED,,
        3,0.9;0.6;0.4,TRUE,disagg,16000,,0,8,critical,32,0,FALSE

    ``EXPECTED_OUTPUT_ROW`` is the leading part of the result row expected for it:
        deepseek-ai/DeepSeek-V3,ATLAS_800_A3_752T_128G_DIE,64,3500,1000,
        deepseek-ai/DeepSeek-V3,W8A8_DYNAMIC,DISABLED,,3,20.00,512,18.81,
        27216.8,425.3,68.50,13.86,17.12,0.52,1,1,64,
        (the prefill cells that follow are empty)

    ``EXPECTED_CASE`` holds the parsed form of the input row and
    ``EXPECTED_RESULT`` the keyword arguments of the matching ``BenchmarkResult``.
    """

    INPUT_CSV_ROW = [
        "deepseek-ai/DeepSeek-V3",
        "ATLAS_800_A3_752T_128G_DIE",
        "64",
        "deepseek-ai/DeepSeek-V3",
        "3500",
        "1000",
        "",
        "20",
        "1",
        "W8A8_DYNAMIC",
        "DISABLED",
        "",
        "3",
        "0.9;0.6;0.4",
        "TRUE",
        "disagg",
        "16000",
        "",
        "0",
        "8",
        "critical",
        "32",
        "0",
        "FALSE",
    ]

    EXPECTED_CASE = {
        "case_name": "deepseek-ai/DeepSeek-V3",
        "device": "ATLAS_800_A3_752T_128G_DIE",
        "num_devices": 64,
        "model_id": "deepseek-ai/DeepSeek-V3",
        "input_length": 3500,
        "output_length": 1000,
        "ttft_limits": [],
        "tpot_limits": [20.0],
        "tp_sizes": [1],
        "quantize_linear_action": QuantizeLinearAction.W8A8_DYNAMIC,
        "quantize_attention_action": QuantizeAttentionAction.DISABLED,
        "ep_sizes": None,
        "num_mtp_tokens": 3,
        "mtp_acceptance_rate": [0.9, 0.6, 0.4],
        "do_compile": True,
        "mode": "disagg",
        "max_prefill_tokens": 16000,
        "batch_range": None,
        "serving_cost": 0.0,
        "jobs": 8,
        "log_level": "critical",
        "mxfp4_group_size": 32,
        "reserved_memory_gb": 0.0,
        "compile_allow_graph_break": False,
    }

    EXPECTED_RESULT = {
        "case_name": "deepseek-ai/DeepSeek-V3",
        "device": "ATLAS_800_A3_752T_128G_DIE",
        "num_devices": 64,
        "model_id": "deepseek-ai/DeepSeek-V3",
        "input_length": 3500,
        "output_length": 1000,
        "best_decode_linear_quant_type": "W8A8_DYNAMIC",
        "best_decode_attn_quant_type": "DISABLED",
        "best_decode_tp_size": 1,
        "best_decode_pp_size": 1,
        "best_decode_dp_size": 64,
        "best_decode_use_ep": "",
        "best_decode_mtp_tokens": 3,
        "best_decode_slo_target_ms": 20.0,
        "best_decode_concurrency": 512,
        "best_decode_tpot_ms": 18.81,
        "best_decode_total_tps": 27216.8,
        "best_decode_tps_per_device": 425.3,
        "best_decode_mem_pct": "68.50",
        "best_decode_comm_pct": "13.86",
        "best_decode_cube_pct": "17.12",
        "best_decode_vec_pct": "0.52",
    }

    EXPECTED_OUTPUT_ROW = [
        "deepseek-ai/DeepSeek-V3",
        "ATLAS_800_A3_752T_128G_DIE",
        64,
        3500,
        1000,
        "deepseek-ai/DeepSeek-V3",
        "W8A8_DYNAMIC",
        "DISABLED",
        "",
        "3",
        "20.00",
        "512",
        "18.81",
        "27216.8",
        "425.3",
        "68.50",
        "13.86",
        "17.12",
        "0.52",
        "1",
        "1",
        "64",
    ]

    def setUp(self):
        """Create a scratch directory for the CSV files of one test."""
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory; removal problems are ignored."""
        shutil.rmtree(self.tmpdir, ignore_errors=True)

    def _write_csv(self, rows, header=None):
        """Write *header* (default ``CSV_CONFIG_HEADER``) and *rows*; return the file path."""
        target = os.path.join(self.tmpdir, "cases.csv")
        columns = CSV_CONFIG_HEADER if header is None else header
        with open(target, "w", newline="", encoding="utf-8") as handle:
            out = csv.writer(handle)
            out.writerow(columns)
            out.writerows(rows)
        return target

    def _load_example_cases(self):
        """Write the example input row to disk and load it back as cases."""
        return load_cases_from_csv(self._write_csv([self.INPUT_CSV_ROW]))

    def _expected_result(self):
        """Build the ``BenchmarkResult`` described by ``EXPECTED_RESULT``."""
        return BenchmarkResult(**self.EXPECTED_RESULT)

    def test_csv_load_parses_example_case(self):
        """Every field listed in ``EXPECTED_CASE`` matches the loaded case."""
        cases = self._load_example_cases()
        self.assertEqual(len(cases), 1)
        c = cases[0]
        for field, expected in self.EXPECTED_CASE.items():
            actual = getattr(c, field)
            self.assertEqual(
                actual,
                expected,
                f"BenchmarkCase.{field}: expected {expected!r}, got {actual!r}",
            )

    def test_build_optimizer_args_from_example_case(self):
        """The optimizer arguments derived from the example case carry its settings."""
        cases = self._load_example_cases()
        args = _build_optimizer_args(cases[0])

        # Identity and workload.
        self.assertEqual(args.device, "ATLAS_800_A3_752T_128G_DIE")
        self.assertEqual(args.num_devices, 64)
        self.assertEqual(args.model_id, "deepseek-ai/DeepSeek-V3")
        self.assertEqual(args.input_length, 3500)
        self.assertEqual(args.output_length, 1000)
        # Latency limits and execution switches.
        self.assertEqual(args.tpot_limits, 20.0)
        self.assertIsNone(args.ttft_limits)
        self.assertTrue(args.disagg)
        self.assertTrue(args.compile)
        self.assertFalse(args.compile_allow_graph_break)
        # Speculative decoding and quantization.
        self.assertEqual(args.num_mtp_tokens, 3)
        self.assertEqual(args.mtp_acceptance_rate, [0.9, 0.6, 0.4])
        self.assertEqual(args.quantize_linear_action, QuantizeLinearAction.W8A8_DYNAMIC)
        self.assertEqual(args.quantize_attention_action, QuantizeAttentionAction.DISABLED)
        # Remaining settings.
        self.assertEqual(args.tp_sizes, [1])
        self.assertEqual(args.max_prefill_tokens, 16000)
        self.assertEqual(args.log_level, "critical")
        self.assertIsNone(args.image_batch_size)
        self.assertEqual(args.prefix_cache_hit_rate, 0.0)
        self.assertFalse(args.enable_optimize_prefill_decode_ratio)

    def test_result_row_from_example_output(self):
        """The rendered row starts with ``EXPECTED_OUTPUT_ROW``; columns 22-37 are empty."""
        row = _result_row(self._expected_result())
        for i, expected in enumerate(self.EXPECTED_OUTPUT_ROW):
            self.assertEqual(row[i], expected, f"Column {i}: expected {expected!r}, got {row[i]!r}")
        for i in range(22, 38):
            self.assertEqual(row[i], "", f"Prefill column {i} should be empty, got {row[i]!r}")

    def test_full_csv_roundtrip(self):
        """Load the example case, save the expected result, and verify the written file."""
        cases = self._load_example_cases()
        self.assertEqual(len(cases), 1)

        result = self._expected_result()
        out_path = os.path.join(self.tmpdir, "results.csv")
        save_results_to_csv([result], out_path)

        with open(out_path, "r", encoding="utf-8") as handle:
            lines = csv.reader(handle)
            header = next(lines)
            ref_row = next(lines)
            data_row = next(lines)

        expected_header, _ = _csv_header_and_ref_row()
        self.assertEqual(header, expected_header)
        # The last two cells of the reference row must not be blank.
        self.assertTrue(len(ref_row[-2]) > 0)
        self.assertTrue(len(ref_row[-1]) > 0)
        for i, expected in enumerate(self.EXPECTED_OUTPUT_ROW):
            self.assertEqual(
                data_row[i],
                str(expected),
                f"Column {i}: expected {expected!r}, got {data_row[i]!r}",
            )

    def test_parse_parallel_disagg_output(self):
        """The verbose parallel string of the example output parses to ``(1, 1, 64)``."""
        tp, pp, dp = _parse_parallel("TP=1 | PP=1 | DP=64")
        self.assertEqual(tp, 1)
        self.assertEqual(pp, 1)
        self.assertEqual(dp, 64)
# Run the whole suite with the standard unittest runner when this file is
# executed directly as a script.
if __name__ == "__main__":
    unittest.main()