"""Extended unit tests for ``serving_cast/service/optimizer_summary.py`` (serving_cast UT suite).
Complements ``test_optimizer_summary.py`` in this directory with helper/render/branch coverage.
"""
import sys
from types import ModuleType, SimpleNamespace
from unittest import TestCase
from unittest.mock import patch
import pandas as pd
import serving_cast.service.optimizer_summary as optimizer_summary_module
from serving_cast.service.optimizer_summary import (
OptimizerSummary,
_compute_disagg_request_qps,
_fmt_optional,
_get_agg_table_buf,
_get_disagg_table_buf,
_get_pd_ratio_table_buf,
_positive_float,
_sorted_rows,
render_cross_device_comparison,
render_cross_hardware_disagg_decode,
render_cross_hardware_disagg_prefill,
render_cross_hardware_pd_ratio,
render_hardware_profile_comparison,
)
class TestPositiveFloat(TestCase):
    """Checks for the ``_positive_float`` helper."""

    def test_positive_float_accept_and_reject(self):
        """``1.5`` is returned unchanged; ``0``, ``None`` and ``"bad"`` all give ``None``."""
        accepted = _positive_float(1.5)
        self.assertEqual(accepted, 1.5)

        # Same order as the individual assertions this loop replaces.
        for rejected_input in (0, None, "bad"):
            self.assertIsNone(_positive_float(rejected_input))
class TestFmtOptionalSortedRows(TestCase):
    """Checks for the ``_fmt_optional`` and ``_sorted_rows`` helpers."""

    def test_fmt_optional_formats_or_dash(self):
        """``3.14159`` is rendered as ``"3.14"`` and ``None`` as ``"-"``."""
        expectations = (
            (3.14159, "3.14"),
            (None, "-"),
        )
        for raw_value, rendered in expectations:
            self.assertEqual(_fmt_optional(raw_value), rendered)

    def test_sorted_rows_orders_by_metric(self):
        """Rows keyed by ``k`` come back ordered 3, 2, 1."""
        unsorted_rows = [{"k": 1}, {"k": 3}, {"k": 2}]
        result = _sorted_rows(unsorted_rows, "k")
        observed_keys = [entry["k"] for entry in result]
        self.assertEqual(observed_keys, [3, 2, 1])
class TestComputeDisaggRequestQps(TestCase):
    """Checks for ``_compute_disagg_request_qps`` on prefill-only and decode-only rows."""

    def test_prefill_formula(self):
        """With only ``ttft`` set, the result equals ``concurrency / ttft * 1000``."""
        prefill_row = pd.Series({"concurrency": 10.0, "ttft": 50.0, "tpot": None})
        expected_qps = 10.0 / 50.0 * 1000.0
        actual_qps = _compute_disagg_request_qps(prefill_row, 64)
        self.assertAlmostEqual(actual_qps, expected_qps)

    def test_decode_formula_requires_output_length(self):
        """With only ``tpot`` set, a ``None`` or zero output length gives ``None``."""
        decode_row = pd.Series({"concurrency": 8.0, "ttft": None, "tpot": 2.0})
        for unusable_length in (None, 0):
            self.assertIsNone(_compute_disagg_request_qps(decode_row, unusable_length))

        expected_qps = 8.0 / (2.0 * 4.0) * 1000.0
        self.assertAlmostEqual(_compute_disagg_request_qps(decode_row, 4), expected_qps)

    def test_returns_none_when_both_ttft_and_tpot(self):
        """A row carrying both ``ttft`` and ``tpot`` gives ``None``."""
        mixed_row = pd.Series({"concurrency": 1.0, "ttft": 1.0, "tpot": 1.0})
        outcome = _compute_disagg_request_qps(mixed_row, 8)
        self.assertIsNone(outcome)

    def test_returns_none_when_concurrency_invalid(self):
        """A zero ``concurrency`` gives ``None`` even when ``ttft`` is set."""
        idle_row = pd.Series({"concurrency": 0.0, "ttft": 50.0, "tpot": None})
        outcome = _compute_disagg_request_qps(idle_row, 8)
        self.assertIsNone(outcome)
class TestDisaggPdRatioTableBuf(TestCase):
    """Checks the title and QPS cell emitted by ``_get_disagg_table_buf``."""

    def test_disagg_prefill_table_title_and_qps_cell(self):
        """A ``ttft``-only frame yields the prefill title and the expected QPS text."""
        prefill_columns = {
            "token/s": [88.0],
            "ttft": [110.0],
            "tpot": [pd.NA],
            "concurrency": [22.0],
            "num_devices": [1],
            "parallel": ["tp2"],
            "batch_size": [1],
        }
        table_text = _get_disagg_table_buf(pd.DataFrame(prefill_columns), output_length=None)
        self.assertRegex(table_text, r"PD Disaggregated Prefill Configurations:")

        qps = 22.0 / 110.0 * 1000.0
        self.assertIn(f"{qps:.2f}", table_text)

    def test_disagg_decode_table_uses_decode_title(self):
        """A ``tpot``-only frame yields the decode title and the expected QPS text."""
        decode_columns = {
            "token/s": [50.0],
            "ttft": [pd.NA],
            "tpot": [2.5],
            "concurrency": [5.0],
            "num_devices": [1],
            "parallel": ["tp1"],
            "batch_size": [1],
        }
        table_text = _get_disagg_table_buf(pd.DataFrame(decode_columns), output_length=4)
        self.assertRegex(table_text, r"PD Disaggregated Decode Configurations:")

        qps = 5.0 / (2.5 * 4.0) * 1000.0
        self.assertIn(f"{qps:.2f}", table_text)
class TestPdRatioTableBuf(TestCase):
    """Checks the text produced by ``_get_pd_ratio_table_buf``."""

    def test_get_pd_ratio_table_buf_contains_banner_and_columns(self):
        """The rendered table contains the banner and the ``Balanced QPS`` header."""
        ratio_columns = {
            "pd_ratio": [0.5],
            "balanced_qps": [12.34],
            "p_qps": [10.0],
            "d_qps": [20.0],
            "ttft_p": [30.0],
            "tpot_d": [1.5],
            "parallel_p": ["Pa"],
            "parallel_d": ["Da"],
            "num_devices_p": [4],
            "num_devices_d": [4],
            "batch_size_p": [1],
            "batch_size_d": [2],
            "concurrency_p": [3],
            "concurrency_d": [4],
        }
        table_text = _get_pd_ratio_table_buf(pd.DataFrame(ratio_columns))
        for fragment in ("PD Ratio Configurations:", "Balanced QPS"):
            self.assertIn(fragment, table_text)
class TestRenderComparisonTables(TestCase):
    """Checks for the cross-device / cross-hardware ``render_*`` helpers."""

    def test_render_helpers_empty_lists(self):
        """Each renderer returns an empty string for an empty row list."""
        renderers = (
            render_cross_device_comparison,
            render_cross_hardware_pd_ratio,
            render_cross_hardware_disagg_prefill,
            render_cross_hardware_disagg_decode,
        )
        for render in renderers:
            self.assertEqual(render([]), "")

    def test_render_cross_device_comparison_non_empty(self):
        """A single row renders with the banner text and the device name."""
        device_row = {
            "device": "D1",
            "throughput_tps": 99.9,
            "concurrency": 1,
            "parallel": "p",
            "batch_size": 1,
            "num_devices": 1,
        }
        rendered = render_cross_device_comparison([device_row])
        for fragment in ("Cross-hardware", "D1"):
            self.assertIn(fragment, rendered)

    def test_render_cross_hardware_pd_ratio_shows_banner(self):
        """The PD-ratio table mentions ``PD Ratio`` and a ``num-devices=`` marker."""
        ratio_row = {
            "device": "X",
            "balanced_qps": 1.23,
            "pd_ratio": 0.25,
            "p_qps": 4.0,
            "d_qps": 1.0,
            "ttft_p": 50.0,
            "tpot_d": 2.0,
            "p_instances": 2,
            "d_instances": 1,
            "total_devices": 8,
        }
        rendered = render_cross_hardware_pd_ratio([ratio_row])
        self.assertIn("PD Ratio", rendered)
        # The marker is matched case-insensitively.
        self.assertIn("num-devices=", rendered.lower())

    def test_render_cross_hardware_disagg_prefill_decode(self):
        """Prefill and decode renderers each emit their own section title."""
        prefill_row = {
            "device": "P1",
            "throughput_tps": 80.0,
            "qps_req_s": None,
            "ttft_ms": 100.0,
            "concurrency": 2,
        }
        prefill_text = render_cross_hardware_disagg_prefill([prefill_row])
        self.assertIn("PD Disaggregated Prefill", prefill_text)

        decode_row = {
            "device": "D2",
            "throughput_tps": 90.0,
            "qps_req_s": 1.23,
            "tpot_ms": 20.0,
            "concurrency": 3,
        }
        decode_text = render_cross_hardware_disagg_decode([decode_row])
        self.assertIn("PD Disaggregated Decode", decode_text)

    def test_render_cross_hardware_pd_ratio_num_devices_banner(self):
        """A row with ``total_devices=16`` produces ``--num-devices=16`` in the output."""
        ratio_row = {
            "device": "Hw1",
            "balanced_qps": 9.0,
            "pd_ratio": 0.5,
            "p_qps": 18.0,
            "d_qps": 9.0,
            "ttft_p": 60.0,
            "tpot_d": 1.5,
            "p_instances": 4,
            "d_instances": 2,
            "total_devices": 16,
        }
        rendered = render_cross_hardware_pd_ratio([ratio_row])
        self.assertIn("--num-devices=16", rendered)
class TestOptimizerSummaryBranches(TestCase):
    """Branch-level scenarios for ``OptimizerSummary`` reporting and row collection."""

    def test_report_final_result_silent_returns_immediately(self):
        """With ``silent=True`` nothing is printed even though a summary frame is set."""
        summary = OptimizerSummary(SimpleNamespace(ttft_limits=10.0, tpot_limits=10.0, output_length=8))
        frame = pd.DataFrame({"token/s": [1.0], "ttft": [1.0], "tpot": [1.0], "concurrency": [1]})
        summary.set_summary_df(frame)

        cli_args = SimpleNamespace(disagg=False, dump_original_results=False)
        with patch("builtins.print") as fake_print:
            summary.report_final_result(cli_args, silent=True)
        fake_print.assert_not_called()

    def test_report_final_result_warns_when_no_summary(self):
        """Reporting without a summary frame logs a warning containing ``empty or unset``."""
        summary = OptimizerSummary(SimpleNamespace(ttft_limits=10.0, tpot_limits=10.0, output_length=8))
        cli_args = SimpleNamespace(disagg=False, dump_original_results=False)

        with self.assertLogs("serving_cast.service.optimizer_summary", level="WARNING") as captured_logs:
            summary.report_final_result(cli_args)

        matched = any("empty or unset" in message for message in captured_logs.output)
        self.assertTrue(matched)

    def test_get_agg_disagg_final_out_empty_after_filters(self):
        """When every row exceeds the limits a warning is logged and a notice is returned."""
        summary = OptimizerSummary(SimpleNamespace(ttft_limits=10.0, tpot_limits=10.0, output_length=None))
        # ttft/tpot are far above the configured limits of 10.0.
        over_limit_columns = {
            "token/s": [1.0],
            "ttft": [1000.0],
            "tpot": [500.0],
            "concurrency": [1],
            "num_devices": [1],
            "parallel": ["x"],
            "batch_size": [1],
        }
        summary.set_summary_df(pd.DataFrame(over_limit_columns))
        cli_args = SimpleNamespace(
            model_id="m",
            num_devices=1,
            device="TEST",
            quantize_linear_action="DISABLED",
            quantize_attention_action="DISABLED",
            disagg=False,
        )

        with self.assertLogs("serving_cast.service.optimizer_summary", level="WARNING") as captured_logs:
            lines = summary._get_agg_disagg_final_out(cli_args)

        self.assertTrue(any("TTFT/TPOT filters" in message for message in captured_logs.output))
        self.assertIn("No configurations satisfy", "\n".join(lines))

    def test_collect_comparison_row_via_best_agg_disagg_row(self):
        """The comparison row carries the device label and a throughput of ``30.0``."""
        summary = OptimizerSummary(SimpleNamespace(ttft_limits=1000.0, tpot_limits=50.0, output_length=32))
        two_row_columns = {
            "token/s": [10.0, 30.0],
            "ttft": [90.0, 80.0],
            "tpot": [5.0, 6.0],
            "concurrency": [1, 2],
            "num_devices": [8, 8],
            "parallel": ["tp1", "tp2"],
            "batch_size": [1, 1],
        }
        summary.set_summary_df(pd.DataFrame(two_row_columns))

        best = summary.collect_comparison_row("device_a")
        self.assertEqual(best["device"], "device_a")
        self.assertEqual(best["throughput_tps"], 30.0)

    def test_collect_disagg_prefill_decode_guards(self):
        """Prefill/decode collectors return ``None`` for these limit combinations."""
        guarded_cases = (
            # (collector method, ttft_limits, tpot_limits, output_length)
            ("collect_disagg_prefill_row", None, None, None),
            ("collect_disagg_prefill_row", 100.0, 10.0, None),
            ("collect_disagg_decode_row", 100.0, 10.0, 4),
        )
        for collector_name, ttft_limit, tpot_limit, out_len in guarded_cases:
            cfg = SimpleNamespace(ttft_limits=ttft_limit, tpot_limits=tpot_limit, output_length=out_len)
            collector = getattr(OptimizerSummary(cfg), collector_name)
            self.assertIsNone(collector("d"))

    def test_collect_pd_ratio_comparison_row_needs_pd_mode_and_data(self):
        """A config without the per-instance device fields yields no PD-ratio row."""
        plain_summary = OptimizerSummary(SimpleNamespace(ttft_limits=100.0, tpot_limits=10.0, num_devices=None))
        plain_summary.set_summary_df(pd.DataFrame({"x": [1]}))
        self.assertIsNone(plain_summary.collect_pd_ratio_comparison_row("d"))
def _baseline_agg_row(updates=None):
    """Build a fresh aggregated-result row dict, optionally overriding fields.

    Args:
        updates: Optional mapping merged over the defaults via ``dict.update``;
            falsy values (``None``, empty mapping) leave the defaults untouched.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    row = dict(
        [
            ("token/s", 10.0),
            ("ttft", 120.0),
            ("tpot", 40.0),
            ("concurrency", 8),
            ("num_devices", 4),
            ("parallel", "tp1"),
            ("batch_size", 1),
        ]
    )
    # ``updates or {}`` makes the merge a no-op for falsy input.
    row.update(updates or {})
    return row
def _baseline_pd_ratio_row(**overrides):
    """Build a fresh PD-ratio result row dict, with keyword overrides applied last.

    Args:
        **overrides: Field values that replace (or extend) the defaults.

    Returns:
        A new dict on every call; default keys keep their order and any new
        override keys are appended after them.
    """
    defaults = {
        "balanced_qps": 222.222,
        "pd_ratio": 0.625,
        "p_qps": 40.0,
        "d_qps": 25.0,
        "ttft_p": 200.0,
        "tpot_d": 12.5,
        "parallel_p": "Pa",
        "parallel_d": "Db",
        "num_devices_p": 2,
        "num_devices_d": 4,
        "batch_size_p": 1,
        "batch_size_d": 2,
        "concurrency_p": 4,
        "concurrency_d": 16,
    }
    return {**defaults, **overrides}
class TestOptimizerSummaryEarlyStopHelpers(TestCase):
    """Checks the summary-frame accessors and early-stop flag of ``OptimizerSummary``."""

    def test_set_get_summary_accessor(self):
        """A frame stored with ``set_summary_df`` is returned equal by ``get_summary_df``."""
        summary = OptimizerSummary(SimpleNamespace(ttft_limits=1.0, tpot_limits=1.0))
        frame = pd.DataFrame({"z": [1]})
        summary.set_summary_df(frame)
        pd.testing.assert_frame_equal(summary.get_summary_df(), frame)

    def test_early_stop_flags(self):
        """The early-stop flag follows memory headroom and the configured latency limits."""
        limited = OptimizerSummary(SimpleNamespace(ttft_limits=50.0, tpot_limits=30.0, output_length=8))
        # (arguments to set_early_stop_flag, expected truthiness of the flag);
        # applied in order against the same instance.
        scenarios = (
            ({"memory_left": -1, "tpot": None, "ttft": None}, True),
            ({"memory_left": 1, "tpot": None, "ttft": None}, False),
            ({"memory_left": 1, "tpot": 60.0, "ttft": None}, True),
            ({"memory_left": 1, "tpot": None, "ttft": 100.0}, True),
        )
        for flag_kwargs, should_stop in scenarios:
            limited.set_early_stop_flag(**flag_kwargs)
            verify = self.assertTrue if should_stop else self.assertFalse
            verify(limited.check_early_stop_flag())

        # Without limits, large latencies do not raise the flag.
        unlimited = OptimizerSummary(SimpleNamespace(ttft_limits=None, tpot_limits=None, output_length=None))
        unlimited.set_early_stop_flag(memory_left=1, tpot=999.0, ttft=999.0)
        self.assertFalse(unlimited.check_early_stop_flag())
class TestOptimizerSummaryReportAndCollect(TestCase):
    """Report-rendering and row-collection scenarios for ``OptimizerSummary``."""

    @staticmethod
    def _summary_with_rows(cfg, rows):
        """Return an ``OptimizerSummary`` for ``cfg`` whose summary frame is built from ``rows``."""
        summary = OptimizerSummary(cfg)
        summary.set_summary_df(pd.DataFrame(rows))
        return summary

    def test_best_agg_disabled_in_pd_ratio_mode(self):
        """With per-instance device counts configured, ``collect_comparison_row`` gives ``None``."""
        pd_mode_cfg = SimpleNamespace(
            ttft_limits=1000.0,
            tpot_limits=100.0,
            output_length=8,
            prefill_devices_per_instance=2,
            decode_devices_per_instance=2,
        )
        summary = self._summary_with_rows(pd_mode_cfg, [_baseline_agg_row()])
        self.assertIsNone(summary.collect_comparison_row("x"))

    def test_collect_disagg_prefill_and_decode_success(self):
        """Prefill-only and decode-only configs each produce a labelled row."""
        prefill_summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=500.0, tpot_limits=None, output_length=8),
            [
                _baseline_agg_row(
                    {
                        "token/s": 111.1,
                        "ttft": 200.0,
                        "tpot": float("nan"),
                        "parallel": "pref",
                    }
                ),
            ],
        )
        prefill_row = prefill_summary.collect_disagg_prefill_row("Pdev")
        self.assertEqual(prefill_row["device"], "Pdev")
        self.assertAlmostEqual(prefill_row["throughput_tps"], 111.1)

        decode_summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=None, tpot_limits=50.0, output_length=4),
            [
                _baseline_agg_row(
                    {
                        "token/s": 50.0,
                        "tpot": 5.0,
                        "ttft": float("nan"),
                        "parallel": "dec",
                    }
                ),
            ],
        )
        decode_row = decode_summary.collect_disagg_decode_row("Ddev")
        self.assertEqual(decode_row["device"], "Ddev")

    def test_row_dict_na_latency_fields(self):
        """``pd.NA`` latencies surface as ``None`` in the collected comparison row."""
        na_row = _baseline_agg_row(
            {
                "token/s": 333.3,
                "ttft": pd.NA,
                "tpot": pd.NA,
                "parallel": "na_row",
                "concurrency": 9,
                "num_devices": 1,
                "batch_size": 2,
            }
        )
        summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=None, tpot_limits=None, output_length=None),
            [na_row],
        )
        collected = summary.collect_comparison_row("uut")
        for latency_key in ("ttft_ms", "tpot_ms"):
            self.assertIsNone(collected[latency_key])
        self.assertEqual(collected["parallel"], "na_row")

    def test_prepare_pd_ratio_dedupe_and_comparison_row_instances(self):
        """Prepared PD-ratio results are non-empty and feed the comparison row."""
        cfg = SimpleNamespace(
            ttft_limits=9999.0,
            tpot_limits=9999.0,
            num_devices=32,
            prefill_devices_per_instance=2,
            decode_devices_per_instance=4,
        )
        top_row = _baseline_pd_ratio_row(
            balanced_qps=500.501,
            parallel_p="P1",
            parallel_d="D1",
        )
        second_row = _baseline_pd_ratio_row(
            balanced_qps=490.1,
            parallel_p="P1",
            parallel_d="D9",
            pd_ratio=0.75,
            num_devices_d=8,
            num_devices_p=2,
        )
        # Copy of ``second_row`` whose balanced QPS differs by only 0.008.
        near_duplicate = dict(second_row)
        near_duplicate["balanced_qps"] = second_row["balanced_qps"] + 0.008
        summary = self._summary_with_rows(cfg, [top_row, second_row, near_duplicate])

        prepared = summary._prepare_pd_ratio_results()
        self.assertFalse(prepared.empty)
        self.assertTrue((prepared["balanced_qps"] <= 500.601).any())

        comparison = summary.collect_pd_ratio_comparison_row("hw-X")
        self.assertEqual(comparison["device"], "hw-X")
        self.assertAlmostEqual(comparison["balanced_qps"], prepared.iloc[0]["balanced_qps"], places=6)
        for instance_key in ("p_instances", "d_instances"):
            self.assertIsNotNone(comparison.get(instance_key))
        self.assertEqual(comparison.get("total_devices"), 32)

    def test_get_agg_disagg_final_out_disagg_branch(self):
        """With ``disagg=True`` the output mentions ``PD Disaggregated Prefill``."""
        summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=500.0, tpot_limits=40.0, output_length=None),
            [
                _baseline_agg_row(
                    {
                        "token/s": 555.5,
                        "ttft": 35.0,
                        "tpot": 8.0,
                        "parallel": "pref",
                    }
                ),
            ],
        )
        cli_args = SimpleNamespace(
            model_id="m",
            num_devices=32,
            device="DEVICE",
            quantize_linear_action="OFF",
            quantize_attention_action="OFF",
            disagg=True,
        )
        report_text = "\n".join(summary._get_agg_disagg_final_out(cli_args))
        self.assertIn("PD Disaggregated Prefill", report_text)

    def test_get_agg_table_buf_contains_rows(self):
        """The aggregated table shows its title and the ``777.77`` throughput."""
        rows = [
            _baseline_agg_row({"token/s": 777.77, "ttft": 10.0, "tpot": 3.33, "parallel": "px"}),
            _baseline_agg_row({"token/s": 5.5, "ttft": 20.0, "tpot": 1.1, "parallel": "py"}),
        ]
        table_text = _get_agg_table_buf(pd.DataFrame(rows))
        for fragment in ("Aggregated Configurations", "777.77"):
            self.assertIn(fragment, table_text)

    def test_report_final_agg_dump_original_and_normal(self):
        """Both the raw-dump and the normal report paths print something."""
        summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=500.0, tpot_limits=40.0, output_length=8),
            [_baseline_agg_row({"token/s": 12.34})],
        )

        raw_dump_args = SimpleNamespace(
            disagg=False,
            dump_original_results=True,
            model_id="_",
            num_devices=1,
            device="-",
            quantize_linear_action="",
            quantize_attention_action="",
        )
        with patch("builtins.print") as dump_print:
            summary.report_final_result(raw_dump_args, silent=False)
        self.assertGreaterEqual(dump_print.call_count, 1)

        pretty_args = SimpleNamespace(
            disagg=False,
            dump_original_results=False,
            model_id="m",
            num_devices=1,
            device="dev",
            quantize_linear_action="QL",
            quantize_attention_action="QA",
        )
        with patch("builtins.print") as pretty_print:
            summary.report_final_result(pretty_args, silent=False)
        # Only calls that passed positional arguments contribute text.
        printed_text = "".join(
            recorded.args[0] for recorded in pretty_print.call_args_list if recorded.args
        )
        self.assertIn("Overall Best", printed_text)

    def test_report_pd_ratio_dump_empty_filtered_infos(self):
        """A PD-ratio dump with out-of-limit latencies still logs at INFO or above."""
        cfg = SimpleNamespace(
            ttft_limits=500.0,
            tpot_limits=40.0,
            output_length=8,
            prefill_devices_per_instance=4,
            decode_devices_per_instance=4,
            num_devices=128,
        )
        over_limit_row = _baseline_pd_ratio_row(
            balanced_qps=999.99,
            ttft_p=1e9,
            tpot_d=1e9,
            parallel_p="_",
            parallel_d="_",
            num_devices_p=2,
            num_devices_d=2,
            batch_size_p=1,
            batch_size_d=1,
            concurrency_p=1,
            concurrency_d=1,
        )
        summary = self._summary_with_rows(cfg, [over_limit_row])
        cli_args = SimpleNamespace(dump_original_results=True, device="CARD", model_id="MID")
        with self.assertLogs(optimizer_summary_module.logger, level="INFO"), patch("builtins.print"):
            summary.report_final_result(cli_args, silent=False)

    def test_report_pd_ratio_dump_non_empty_df(self):
        """A PD-ratio raw dump prints text that includes the ``balanced_qps`` column name."""
        cfg = SimpleNamespace(
            ttft_limits=800.0,
            tpot_limits=80.0,
            output_length=8,
            prefill_devices_per_instance=2,
            decode_devices_per_instance=2,
            num_devices=None,
        )
        summary = self._summary_with_rows(cfg, [_baseline_pd_ratio_row(parallel_p="PP", parallel_d="DD")])
        cli_args = SimpleNamespace(dump_original_results=True, device="CARD", model_id="MID")

        printed_chunks = []

        def _record(*parts, **_kwargs):
            # Stringify every positional argument, like ``print`` would.
            printed_chunks.append(" ".join(str(part) for part in parts))

        with patch("builtins.print", side_effect=_record):
            summary.report_final_result(cli_args, silent=False)
        self.assertIn("balanced_qps", "".join(printed_chunks))

    def test_report_pd_ratio_pretty_best_with_instances(self):
        """The pretty PD-ratio output names the model and shows ``P Instances:``."""
        cfg = SimpleNamespace(
            ttft_limits=450.0,
            tpot_limits=45.0,
            output_length=8,
            prefill_devices_per_instance=2,
            decode_devices_per_instance=2,
            num_devices=32,
        )
        summary = self._summary_with_rows(cfg, [_baseline_pd_ratio_row(pd_ratio=0.5)])
        prepared = summary._prepare_pd_ratio_results()
        self.assertFalse(prepared.empty)

        report_lines = summary._get_pd_ratio_final_out(
            SimpleNamespace(model_id="model-x", device="mydev"),
            prepared,
        )
        report_text = "\n".join(report_lines)
        for fragment in ("model-x", "P Instances:"):
            self.assertIn(fragment, report_text)

    def test_report_pd_ratio_pretty_not_dump_calls_print(self):
        """The non-dump PD-ratio report prints at least once."""
        cfg = SimpleNamespace(
            ttft_limits=450.0,
            tpot_limits=45.0,
            output_length=8,
            prefill_devices_per_instance=2,
            decode_devices_per_instance=2,
            num_devices=32,
        )
        summary = self._summary_with_rows(cfg, [_baseline_pd_ratio_row(pd_ratio=0.5)])
        cli_args = SimpleNamespace(
            dump_original_results=False,
            device="CARD",
            model_id="MID",
            quantize_linear_action="OFF",
            quantize_attention_action="OFF",
        )
        with patch("builtins.print") as fake_print:
            summary.report_final_result(cli_args, silent=False)
        self.assertGreaterEqual(fake_print.call_count, 1)

    def test_collect_comparison_returns_none_when_all_rows_filtered_out(self):
        """When the only row exceeds tiny limits, no comparison row is returned."""
        summary = self._summary_with_rows(
            SimpleNamespace(ttft_limits=1e-6, tpot_limits=1e-6, output_length=None),
            [_baseline_agg_row({"token/s": 999.9, "ttft": 1e9, "tpot": 1e9, "parallel": "gone"})],
        )
        self.assertIsNone(summary.collect_comparison_row("dev"))

    def test_collect_pd_ratio_summary_unset_or_filtered_returns_none(self):
        """No PD-ratio comparison row without a frame, or when the row exceeds the limits."""
        pd_cfg = SimpleNamespace(
            ttft_limits=100.0,
            tpot_limits=100.0,
            output_length=None,
            prefill_devices_per_instance=1,
            decode_devices_per_instance=1,
            num_devices=None,
        )
        # Case 1: no summary frame was ever set.
        without_frame = OptimizerSummary(pd_cfg)
        self.assertIsNone(without_frame.collect_pd_ratio_comparison_row("d"))

        # Case 2: the only row has latencies far above the limits.
        over_limit = self._summary_with_rows(pd_cfg, [_baseline_pd_ratio_row(ttft_p=1e9, tpot_d=1e9)])
        self.assertIsNone(over_limit.collect_pd_ratio_comparison_row("d"))
class TestRenderHardwareProfileComparisonStubbedImports(TestCase):
    """Exercise ``render_hardware_profile_comparison`` inner branches without requiring ``torch``.

    Fake ``torch`` and ``tensor_cast`` modules are installed into ``sys.modules``
    through ``patch.dict``, whose cleanup restores the mapping to its pre-patch
    contents. This replaces a hand-rolled snapshot/restore whose selective
    ``pop`` loop was redundant with the ``clear()`` + ``update()`` that followed.
    """

    @staticmethod
    def _grid(shape):
        """Return a minimal stand-in object exposing only a ``shape`` attribute."""
        return SimpleNamespace(shape=shape)

    def _stub_modules_for_hardware_render(self, profile_map, tor_stub):
        """Install stub ``torch`` / ``tensor_cast`` modules for the running test.

        Args:
            profile_map: Mapping exposed as ``DeviceProfile.all_device_profiles``
                on the fake ``tensor_cast.device`` module.
            tor_stub: Module object registered under the name ``torch``.

        The patch is undone automatically via ``addCleanup`` when the test ends.
        """

        class DeviceProfile:
            all_device_profiles = profile_map

        dev_pkg = ModuleType("tensor_cast.device")
        dev_pkg.DeviceProfile = DeviceProfile

        tc_pkg = ModuleType("tensor_cast")
        # An (empty) ``__path__`` marks the stub as a package.
        tc_pkg.__path__ = []
        device_profiles_stub = ModuleType("tensor_cast.device_profiles")
        tc_pkg.device_profiles = device_profiles_stub

        stubbed_modules = {
            "torch": tor_stub,
            "tensor_cast": tc_pkg,
            "tensor_cast.device_profiles": device_profiles_stub,
            "tensor_cast.device": dev_pkg,
        }
        module_patch = patch.dict(sys.modules, stubbed_modules)
        module_patch.start()
        self.addCleanup(module_patch.stop)

    def test_render_profiles_hits_torch_branch_and_notes(self):
        """Render a mix of unknown, full, half-only and empty profiles in one call."""
        bf = object()
        hf = object()
        tor = ModuleType("torch")
        tor.bfloat16 = bf
        tor.half = hf

        prof_full = SimpleNamespace(
            name="full_bf16",
            mma_ops={bf: 400e12},
            gp_ops={
                bf: 80e12,
                hf: 50e12,
            },
            compute_efficiency=0.93,
            memory_bandwidth_bytes_ps=900e9,
            memory_efficiency=0.88,
            memory_size_bytes=96 * (1024**3),
            comm_grid=SimpleNamespace(
                topologies={
                    0: SimpleNamespace(bandwidth_bytes_ps=120e9, comm_efficiency=0.8),
                    3: SimpleNamespace(bandwidth_bytes_ps=60e9, comm_efficiency=0.92),
                },
                grid=self._grid((2, 4)),
            ),
        )
        prof_half_only = SimpleNamespace(
            name="half_only",
            mma_ops={hf: 200e12},
            gp_ops={},
            compute_efficiency=1.0,
            memory_bandwidth_bytes_ps=450e9,
            memory_efficiency=0.5,
            memory_size_bytes=32 * (1024**3),
            comm_grid=SimpleNamespace(
                topologies={1: SimpleNamespace(bandwidth_bytes_ps=30e9, comm_efficiency=0.61)},
                grid=self._grid((8,)),
            ),
        )
        prof_empty_peak = SimpleNamespace(
            name="empty_ops",
            mma_ops={},
            gp_ops={},
            compute_efficiency=1.0,
            memory_bandwidth_bytes_ps=210e9,
            memory_efficiency=1.0,
            memory_size_bytes=128 * (1024**3),
            comm_grid=SimpleNamespace(topologies={}, grid=self._grid((1, 1))),
        )
        profiles = {
            profile.name: profile
            for profile in (prof_full, prof_half_only, prof_empty_peak)
        }
        self._stub_modules_for_hardware_render(profiles, tor)

        # The request includes an unknown device and a repeated name on purpose.
        txt = render_hardware_profile_comparison(
            [
                "missing_device",
                prof_full.name,
                prof_half_only.name,
                prof_empty_peak.name,
                prof_full.name,
            ]
        )
        self.assertIn("missing_device", txt)
        self.assertIn("Notes:", txt)
        self.assertIn("empty_ops", txt)
        self.assertIn(prof_half_only.name, txt)
        self.assertGreaterEqual(txt.count(prof_full.name), 1)

    def test_gp_ops_max_peak_when_bf16_half_missing(self):
        """Render a profile whose op tables are keyed by neither stub dtype."""
        alt = object()
        tor = ModuleType("torch")
        tor.bfloat16 = object()
        tor.half = object()

        prof = SimpleNamespace(
            name="mixed_keys",
            mma_ops={alt: 500e12},
            gp_ops={"z_other": 200e11},
            compute_efficiency=1.0,
            memory_bandwidth_bytes_ps=320e9,
            memory_efficiency=1.0,
            memory_size_bytes=128 * (1024**3),
            comm_grid=SimpleNamespace(
                topologies={0: SimpleNamespace(bandwidth_bytes_ps=10e9, comm_efficiency=0.55)},
                grid=self._grid((4, 2)),
            ),
        )
        self._stub_modules_for_hardware_render({prof.name: prof}, tor)

        txt = render_hardware_profile_comparison([prof.name])
        self.assertIn(prof.name, txt)
        self.assertIn("500.00", txt)
class TestRenderHardwareProfileShortcuts(TestCase):
    """Checks the trivial-input path of ``render_hardware_profile_comparison``."""

    def test_empty_device_name_list_returns_empty_string(self):
        """An empty device-name list renders as an empty string."""
        rendered = render_hardware_profile_comparison([])
        self.assertEqual(rendered, "")