import unittest
from concurrent.futures.process import BrokenProcessPool
from unittest.mock import MagicMock, Mock, patch
from serving_cast.parallel_runner import ParallelRunner
from serving_cast.service.optimizer_summary import OptimizerSummary
from serving_cast.service.utils import OptimizerData
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import DeviceProfile
from .test_common import SimpleArgs
class RuntimeErrorExecutor:
    """Executor stand-in whose ``map`` results always raise ``BrokenProcessPool``.

    Exposes the constructor keywords, context-manager protocol and ``map``
    method that the tests in this module hand to the runner.
    """

    class _AlwaysBrokenResults:
        """Iterator that raises ``BrokenProcessPool`` on every ``next()`` call."""

        def __iter__(self):
            return self

        def __next__(self):
            raise BrokenProcessPool

    def __init__(self, max_workers=None, initializer=None):
        # ``max_workers`` is accepted but not stored; only the initializer
        # is kept so that ``map`` can invoke it.
        self.initializer = initializer

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # A falsy return value lets exceptions raised inside the ``with``
        # body propagate to the caller.
        return None

    def map(self, fn, *iterables, timeout=None, chunksize=1):
        """Call the stored initializer (if any) and return a failing iterator.

        ``fn``, ``iterables``, ``timeout`` and ``chunksize`` are ignored.
        """
        init = self.initializer
        if init is not None:
            init()
        return self._AlwaysBrokenResults()
class TestTaskRunner(unittest.TestCase):
    """Tests for ParallelRunner config enumeration, task execution and error logging."""

    def setUp(self):
        """Set up test fixtures before each test method."""
        self.args = SimpleArgs()
        self.args.serving_cost = 0
        self.args.jobs = 4
        self.device_profile = DeviceProfile.all_device_profiles[self.args.device]

    def test_get_user_config_multiple_tp(self):
        """Test _get_user_config with multiple TP values"""
        self.args.tp_sizes = [2, 4]
        self.args.num_devices = 4
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 2)
        tps = [config.tp_size for config in configs]
        self.assertIn(2, tps)
        self.assertIn(4, tps)
        for config in configs:
            self.assertEqual(config.ep_size, 4)
            self.assertEqual(config.moe_dp_size, 1)

    def test_optimizer_data_loads_length_distribution_from_input_length_file(self):
        """A YAML path in ``input_length`` should populate ``length_distribution`` instead."""
        self.args.input_length = "serving_cast/example/length_distribution.yaml"
        task_runner = ParallelRunner(self.args)
        self.assertIsNone(task_runner.optimizer_data.input_length)
        self.assertIsNotNone(task_runner.optimizer_data.length_distribution)

    def test_get_user_config_default_tps(self):
        """Test _get_user_config with default TP values"""
        self.args.tp_sizes = []
        self.args.num_devices = 8
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        expected_tps = [1, 2, 4, 8]
        actual_tps = [config.tp_size for config in configs]
        for expected_tp in expected_tps:
            self.assertIn(expected_tp, actual_tps)

    def test_get_user_config_tp_ep_combinations(self):
        """Test searching TP/EP with fixed MOE-DP=1."""
        self.args.tp_sizes = [1, 2, 4]
        self.args.ep_sizes = [1, 2, 4]
        self.args.num_devices = 4
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 9)
        target = next(config for config in configs if config.tp_size == 2 and config.ep_size == 2)
        self.assertEqual(target.dp_size, 2)
        self.assertEqual(target.moe_dp_size, 1)
        self.assertEqual(target.moe_tp_size, 2)

    def test_get_user_config_tp_ep_default_ranges(self):
        """Test TP/EP default ranges."""
        self.args.tp_sizes = []
        self.args.ep_sizes = []
        self.args.num_devices = 8
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 16)
        target = next(config for config in configs if config.tp_size == 8 and config.ep_size == 8)
        self.assertEqual(target.dp_size, 1)
        self.assertEqual(target.moe_dp_size, 1)
        self.assertEqual(target.moe_tp_size, 1)

    def test_get_user_config_tp_ep_moe_dp_combinations(self):
        """Test searching TP/EP/MOE-DP combinations."""
        self.args.tp_sizes = [1, 2]
        self.args.ep_sizes = [1, 2, 4]
        self.args.moe_dp_sizes = [1, 2, 4]
        self.args.num_devices = 8
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        keys = {(config.tp_size, config.ep_size, config.moe_dp_size) for config in configs}
        self.assertIn((1, 2, 4), keys)
        self.assertIn((2, 4, 2), keys)
        for config in configs:
            self.assertEqual(
                config.moe_tp_size,
                self.args.num_devices // (config.ep_size * config.moe_dp_size),
            )

    def test_get_user_config_tp_ep_moe_dp_default_ranges(self):
        """Test TP/EP/MOE-DP default ranges."""
        self.args.tp_sizes = []
        self.args.ep_sizes = []
        self.args.moe_dp_sizes = []
        self.args.num_devices = 4
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 18)
        keys = {(config.tp_size, config.ep_size, config.moe_dp_size) for config in configs}
        self.assertIn((4, 4, 1), keys)
        self.assertIn((2, 2, 2), keys)

    def test_get_user_config_num_mtp_tokens_combinations(self):
        """Test searching num_mtp_tokens together with parallel candidates."""
        self.args.tp_sizes = [1, 2]
        self.args.num_devices = 2
        self.args.num_mtp_token_sizes = [0, 2]
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 4)
        keys = {(config.tp_size, config.num_mtp_tokens) for config in configs}
        self.assertEqual(keys, {(1, 0), (1, 2), (2, 0), (2, 2)})

    def test_get_user_config_chrome_trace_names_include_num_mtp_tokens(self):
        """Test MTP search candidates do not overwrite the same chrome trace file."""
        self.args.tp_sizes = [1]
        self.args.num_devices = 1
        self.args.num_mtp_token_sizes = [0, 2]
        self.args.chrome_trace = "trace.json"
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        trace_names = {config.num_mtp_tokens: config.chrome_trace for config in configs}
        self.assertEqual(trace_names[0], "trace_tp1dp1mtp0.json")
        self.assertEqual(trace_names[2], "trace_tp1dp1mtp2.json")
        self.assertEqual(len(set(trace_names.values())), 2)

    def test_get_user_config_tp_ep_num_mtp_tokens_combinations(self):
        """Test TP/EP/MTP search combinations for the throughput optimizer CLI pattern."""
        self.args.tp_sizes = [1, 2]
        self.args.ep_sizes = [1, 2]
        self.args.num_mtp_token_sizes = [1, 2, 3]
        self.args.num_devices = 8
        task_runner = ParallelRunner(self.args)
        configs = list(task_runner._get_user_config())
        self.assertEqual(len(configs), 12)
        keys = {(config.tp_size, config.ep_size, config.num_mtp_tokens) for config in configs}
        self.assertIn((1, 1, 1), keys)
        self.assertIn((1, 2, 3), keys)
        self.assertIn((2, 1, 2), keys)
        self.assertIn((2, 2, 3), keys)

    def test_optimizer_data_uses_safe_num_mtp_tokens_for_multi_candidate_search(self):
        """Test base OptimizerData does not pin the first MTP candidate before task dispatch."""
        self.args.num_mtp_tokens = 1
        self.args.num_mtp_token_sizes = [1, 2, 3]
        task_runner = ParallelRunner(self.args)
        self.assertEqual(task_runner.optimizer_data.num_mtp_tokens, 0)
        configs = list(task_runner._get_user_config())
        self.assertEqual({config.num_mtp_tokens for config in configs}, {1, 2, 3})

    def test_run_with_tpot_limit(self):
        """Test run method with TPOT limit"""
        self.args.tpot_limits = 50
        self.args.batch_range = [2, 2]
        task_runner = ParallelRunner(self.args)
        result = task_runner.run_agg()
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0], OptimizerSummary)
        summary_df = result[0].get_summary_df()
        row = summary_df.iloc[0]
        self.assertEqual(row["concurrency"], 2)

    def test_given_mocked_executor_when_called_then_returns_empty_list_and_verifies_executor_initialization(
        self,
    ):
        """_get_df_list should build the executor with jobs/initializer and return its (empty) results."""
        executor_cls = Mock()
        executor_inst = MagicMock()
        executor_cls.return_value = executor_inst
        # Make ``with executor_cls(...) as ex`` yield the same mock instance.
        executor_inst.__enter__.return_value = executor_inst
        executor_inst.__exit__.return_value = None
        initializer = Mock()

        def fake_map(fn, *iterables, timeout=None, chunksize=1):
            # Invoke the initializer exactly once per map call so the test
            # can verify it afterwards; produce no results.
            initializer()
            return []

        executor_inst.map = fake_map
        task_runner = ParallelRunner(self.args, executor_cls, initializer)
        df_list = task_runner._get_df_list(task_runner.optimizer_data)
        executor_cls.assert_called_once_with(max_workers=self.args.jobs, initializer=initializer)
        initializer.assert_called_once_with()
        self.assertEqual(df_list, [])

    def test_given_worker_initializer_raises_runtime_error_when_called_then_raises_and_logs_expected_errors(
        self,
    ):
        """A broken result iterator should surface as RuntimeError with ERROR log records.

        ``RuntimeErrorExecutor.map`` returns an iterator whose ``__next__``
        raises ``BrokenProcessPool``.
        """
        initializer = Mock()
        task_runner = ParallelRunner(
            self.args,
            executor_class=RuntimeErrorExecutor,
            worker_initializer=initializer,
        )
        with self.assertLogs("serving_cast.parallel_runner", "ERROR") as cm:
            with self.assertRaises(RuntimeError):
                task_runner._get_df_list(task_runner.optimizer_data)
        # ``assertTrue(len(cm.output), 3)`` would treat ``3`` as the failure
        # message and only check truthiness, so assert the count explicitly.
        # At least the two records inspected below must be present; the exact
        # total depends on the runner's logging and is not pinned here.
        self.assertGreaterEqual(len(cm.output), 2)
        self.assertRegex(
            cm.output[0],
            "ERROR:serving_cast.parallel_runner:A worker process crashed unexpectedly during execution. "
            "Common causes: memory issues, unpicklable objects, or unhandled exceptions in worker.",
        )
        self.assertRegex(
            cm.output[1],
            "ERROR:serving_cast.parallel_runner:Executor: RuntimeErrorExecutor, Workers: 4",
        )

    def test_run_disagg_with_ttft_and_tpot_limit(self):
        """Test run_disagg method with ttft and tpot limit"""
        self.args.ttft_limits = 1000
        self.args.tpot_limits = 50
        self.args.batch_range = [2, 2]
        self.args.disagg = True
        task_runner = ParallelRunner(self.args)
        result = task_runner.run_disagg()
        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], OptimizerSummary)
        # First summary: expected to carry no TPOT value.
        prefill_df = result[0].get_summary_df()
        row = prefill_df.iloc[0]
        self.assertEqual(row["concurrency"], 2)
        self.assertIsNone(row["tpot"])
        # Second summary: expected to carry no TTFT value.
        decode_df = result[1].get_summary_df()
        row = decode_df.iloc[0]
        self.assertEqual(row["concurrency"], 2)
        self.assertIsNone(row["ttft"])

    def test_submit_task(self):
        """Test _submit_task method"""
        user_config = UserInputConfig.from_args(self.args)
        optimizer_data = OptimizerData(
            input_length=self.args.input_length,
            output_length=self.args.output_length,
            ttft_limits=1000,
            tpot_limits=50,
            max_batched_tokens=self.args.max_batched_tokens,
            num_devices=self.args.num_devices,
            num_mtp_tokens=1,
            mtp_acceptance_rate=[0.9],
        )
        task_runner = ParallelRunner(self.args)
        result = task_runner._submit_task(user_config, optimizer_data)
        self.assertIsNotNone(result)
        self.assertIsInstance(result, OptimizerSummary)
        row = result.get_summary_df().iloc[0]
        self.assertEqual(row["model_id"], self.args.model_id)
        self.assertEqual(row["parallel"], "TP=1 | PP=1 | DP=1")
class TestParallelRunnerPDMode(unittest.TestCase):
    """Test cases for ParallelRunner PD ratio mode."""

    def setUp(self):
        """Set up test fixtures for PD mode."""
        args = SimpleArgs()
        args.serving_cost = 0
        args.jobs = 4
        args.enable_optimize_prefill_decode_ratio = True
        args.prefill_devices_per_instance = 4
        args.decode_devices_per_instance = 2
        args.input_length = 1024
        args.output_length = 1024
        args.ttft_limits = 100
        args.tpot_limits = 10
        args.num_devices = 8
        args.batch_range = [1, 16]
        self.args = args

    @staticmethod
    def _build_optimizer_data():
        """Return the OptimizerData shared by the ``_add_summary_result`` tests."""
        return OptimizerData(
            input_length=1024,
            output_length=1024,
            ttft_limits=100,
            tpot_limits=10,
        )

    @staticmethod
    def _build_summary(optimizer_data, frame, memory_info=None):
        """Create an OptimizerSummary holding ``frame`` and, when given, ``memory_info``."""
        summary = OptimizerSummary(optimizer_data)
        summary.set_summary_df(frame)
        if memory_info is not None:
            summary.set_memory_info(memory_info)
        return summary

    def test_add_summary_result_with_empty_list(self):
        """Test _add_summary_result with empty df_list."""
        runner = ParallelRunner(self.args)
        data = self._build_optimizer_data()

        runner._add_summary_result([], data)

        self.assertEqual(len(runner.summary_result), 0)

    def test_add_summary_result_with_valid_df(self):
        """Test _add_summary_result with valid DataFrame."""
        import pandas as pd

        runner = ParallelRunner(self.args)
        data = self._build_optimizer_data()
        columns = {
            "ttft": [100.0],
            "tpot": [10.0],
            "concurrency": [10],
            "parallel": ["tp4pp1dp1"],
            "batch_size": [4],
        }
        summary = self._build_summary(data, pd.DataFrame(columns))

        runner._add_summary_result([summary], data)

        self.assertEqual(len(runner.summary_result), 1)

    def test_add_summary_result_selects_tightest_memory_info(self):
        """Merged multi-TP summaries should keep the most constrained memory info."""
        import pandas as pd

        runner = ParallelRunner(self.args)
        data = self._build_optimizer_data()
        frame = pd.DataFrame({"ttft": [100.0], "tpot": [10.0], "token/s": [1.0]})
        loose_memory_info = {
            "total_device_memory_gb": 64.0,
            "reserved_memory_gb": 4.0,
            "device_memory_available_gb": 8.0,
        }
        tight_memory_info = {
            "total_device_memory_gb": 48.0,
            "reserved_memory_gb": 3.0,
            "device_memory_available_gb": 4.0,
        }
        # One summary without memory info, one loose, one tight.
        summaries = [
            self._build_summary(data, frame),
            self._build_summary(data, frame, loose_memory_info),
            self._build_summary(data, frame, tight_memory_info),
        ]

        runner._add_summary_result(summaries, data)

        self.assertEqual(runner.summary_result[0].get_memory_info(), tight_memory_info)

    def test_run_pd_ratio_combines_prefill_and_decode_results(self):
        """_run_pd_ratio should submit both phases and wrap the optimized result."""
        import pandas as pd

        class ImmediateFuture:
            """Future-like object whose result is already known."""

            def __init__(self, result):
                self._result = result

            def result(self):
                return self._result

        class ImmediateThreadPool:
            """Pool replacement that runs submitted callables synchronously."""

            def __init__(self, max_workers=None):
                self.max_workers = max_workers

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return None

            def submit(self, fn, *args, **kwargs):
                outcome = fn(*args, **kwargs)
                return ImmediateFuture(outcome)

        class RecordingPDRatioOptimizer:
            """Optimizer replacement that records its inputs and returns a fixed frame."""

            instances = []

            def __init__(self, output_length):
                self.output_length = output_length
                self.p_df = None
                self.d_df = None
                self.instances.append(self)

            def set_p_results(self, p_df):
                self.p_df = p_df

            def set_d_results(self, d_df):
                self.d_df = d_df

            def optimize(self):
                fixed = {
                    "balanced_qps": [12.0],
                    "pd_ratio": [0.5],
                    "p_qps": [24.0],
                    "d_qps": [12.0],
                }
                return pd.DataFrame(fixed)

        runner = ParallelRunner(self.args)
        phase_calls = []

        def fake_run_pd_phase(devices_per_instance, is_prefill):
            phase_calls.append((devices_per_instance, is_prefill))
            if is_prefill:
                columns = {"p_qps": [24.0]}
                memory = {
                    "total_device_memory_gb": 64.0,
                    "reserved_memory_gb": 4.0,
                    "device_memory_available_gb": 8.0,
                }
            else:
                columns = {"d_qps": [12.0]}
                memory = {
                    "total_device_memory_gb": 48.0,
                    "reserved_memory_gb": 3.0,
                    "device_memory_available_gb": 3.0,
                }
            frame = pd.DataFrame(columns)
            frame.attrs["memory_info"] = memory
            return frame

        runner._run_pd_phase = fake_run_pd_phase

        with patch("serving_cast.parallel_runner.ThreadPoolExecutor", ImmediateThreadPool):
            with patch("serving_cast.parallel_runner.PDRatioThroughputOptimizer", RecordingPDRatioOptimizer):
                result = runner._run_pd_ratio()

        expected_calls = [
            (self.args.prefill_devices_per_instance, True),
            (self.args.decode_devices_per_instance, False),
        ]
        self.assertEqual(phase_calls, expected_calls)

        optimizer = RecordingPDRatioOptimizer.instances[0]
        self.assertEqual(optimizer.output_length, self.args.output_length)
        self.assertEqual(optimizer.p_df.iloc[0]["p_qps"], 24.0)
        self.assertEqual(optimizer.d_df.iloc[0]["d_qps"], 12.0)

        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].get_summary_df().iloc[0]["balanced_qps"], 12.0)
        self.assertEqual(result[0].get_memory_info()["total_device_memory_gb"], 48.0)
        self.assertEqual(result[0].get_memory_info()["device_memory_available_gb"], 3.0)

    def test_pd_phase_forces_disaggregation_strategy(self):
        """PD ratio sub-phases should use disaggregated optimizer semantics."""
        import pandas as pd

        class InlineExecutor:
            """Executor replacement that maps the function in the calling thread."""

            def __init__(self, max_workers=None, initializer=None):
                self.initializer = initializer

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return None

            def map(self, fn, *iterables, timeout=None, chunksize=1):
                if self.initializer is not None:
                    self.initializer()
                # Only the first iterable is consumed, eagerly.
                return list(map(fn, iterables[0]))

        class RecordingParallelRunner(ParallelRunner):
            """Runner that records ``disagg_mode`` and fabricates a summary per task."""

            def __init__(self, args):
                super().__init__(args, executor_class=InlineExecutor)
                self.disagg_modes = []

            def _submit_task(
                self,
                user_input,
                overwrite_optimizer_data,
                disagg_mode=None,
            ):
                self.disagg_modes.append(disagg_mode)
                tp_size = user_input.tp_size
                summary = OptimizerSummary(overwrite_optimizer_data)
                summary.set_summary_df(pd.DataFrame({"ttft": [100.0], "concurrency": [tp_size]}))
                # Available memory shrinks as tp_size grows.
                memory = {
                    "total_device_memory_gb": 64.0,
                    "reserved_memory_gb": 4.0,
                    "device_memory_available_gb": 8.0 / tp_size,
                }
                summary.set_memory_info(memory)
                return summary

        self.args.disagg = False
        self.args.tp_sizes = [1, 2]
        self.args.num_devices = 2
        runner = RecordingParallelRunner(self.args)

        phase_df = runner._run_pd_phase(
            devices_per_instance=self.args.prefill_devices_per_instance,
            is_prefill=True,
        )

        # The caller-visible args must still report disagg disabled afterwards.
        self.assertFalse(self.args.disagg)
        self.assertEqual(phase_df.iloc[0]["ttft"], 100.0)
        self.assertEqual(phase_df.attrs["memory_info"]["total_device_memory_gb"], 64.0)
        self.assertEqual(phase_df.attrs["memory_info"]["device_memory_available_gb"], 4.0)
        # At least one task ran, and every task saw a truthy disagg_mode.
        self.assertTrue(runner.disagg_modes)
        self.assertTrue(all(runner.disagg_modes))
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()