import unittest
from unittest.mock import Mock, patch
import pandas as pd
from serving_cast.service.base_throughput_optimizer import BaseThroughputOptimizer
from serving_cast.service.optimizer_summary import OptimizerSummary
from serving_cast.service.utils import AGG_COLUMNS, OptimizerData
class ConcreteThroughputOptimizer(BaseThroughputOptimizer):
    """Concrete BaseThroughputOptimizer subclass used as the subject of these tests."""

    def initialize(self, model):
        """Keep a reference to ``model`` on the instance."""
        self.model = model

    def get_inference_info(self, optimizer_data):
        """Return a stub summary: early stop is off and the summary frame is empty.

        ``optimizer_data`` is accepted for interface compatibility and is not read.
        """
        empty_frame = pd.DataFrame(columns=AGG_COLUMNS)
        stub = Mock(spec=OptimizerSummary)
        stub.get_summary_df.return_value = empty_frame
        stub.check_early_stop_flag.return_value = False
        return stub
class TestBaseBackend(unittest.TestCase):
    """Unit tests for BaseThroughputOptimizer, driven through ConcreteThroughputOptimizer."""

    def setUp(self):
        """Set up test fixtures before each test method."""
        self.backend = ConcreteThroughputOptimizer()
        self.mock_args = Mock(batch_range=None)
        self.mock_model = Mock()
        self.backend.initialize(self.mock_model)
        self.mock_data_config = Mock(
            batch_size=1,
            input_length=128,
            output_length=64,
            ttft_limits=1000,
            tpot_limits=200,
        )

    def _captured_requests(self):
        """Return the first positional argument of the last ``model_runner.run_inference`` call."""
        return self.backend.model_runner.run_inference.call_args.args[0]

    def test_name_attribute(self):
        """Test that name attribute is set correctly"""
        self.assertEqual(self.backend.name, "base")

    @patch.object(ConcreteThroughputOptimizer, "get_inference_info")
    def test_optimizer_basic(self, mock_get_inference_info):
        """Test optimizer with basic scenario"""

        def side_effect(data_config):
            # Batch sizes below 10 yield one populated summary row; anything
            # else reports early stop together with an empty frame.
            within_limit = data_config.batch_size < 10
            summary = Mock(spec=OptimizerSummary)
            summary.check_early_stop_flag.return_value = not within_limit
            if not within_limit:
                summary.get_summary_df.return_value = pd.DataFrame(columns=AGG_COLUMNS)
                return summary
            row = [
                "TEST_DEVICE",
                1,
                f"model_{data_config.batch_size}",
                "DISABLED",
                "DISABLED",
                128,
                64,
                128,
                8192,
                1,
                data_config.batch_size * 2,
                100.0,
                50.0,
                1000.0,
                500.0,
                "tp1pp1dp1",
                data_config.batch_size,
                "prefill_breakdonws",
                "decode_breakdowns",
            ]
            summary.get_summary_df.return_value = pd.DataFrame(columns=AGG_COLUMNS, data=[row])
            return summary

        mock_get_inference_info.side_effect = side_effect
        result = self.backend.run(self.mock_data_config, [5, 20])
        self.assertGreater(mock_get_inference_info.call_count, 0)
        self.assertIsNotNone(result)

    @patch.object(ConcreteThroughputOptimizer, "get_inference_info")
    def test_optimizer_early_stop(self, mock_get_inference_info):
        """Test optimizer with early stop condition"""
        mock_summary = Mock(spec=OptimizerSummary)
        mock_summary.check_early_stop_flag.return_value = True
        mock_get_inference_info.return_value = mock_summary
        result = self.backend.run(self.mock_data_config, None)
        self.assertIsNone(result)

    @patch.object(ConcreteThroughputOptimizer, "get_inference_info")
    def test_optimizer_no_results(self, mock_get_inference_info):
        """Test optimizer when no valid results found

        Only checks that ``get_inference_info`` was invoked; the value returned
        by ``run`` is deliberately not asserted here.
        """

        def side_effect(data_config):
            summary = Mock(spec=OptimizerSummary)
            summary.check_early_stop_flag.return_value = True
            summary.get_summary_df.return_value = pd.DataFrame(columns=AGG_COLUMNS)
            return summary

        mock_get_inference_info.side_effect = side_effect
        _ = self.backend.run(self.mock_data_config, None)
        mock_get_inference_info.assert_called()

    def test_abstract_methods_exist(self):
        """Test that abstract methods exist"""
        self.assertTrue(hasattr(BaseThroughputOptimizer, "initialize"))
        self.assertTrue(hasattr(BaseThroughputOptimizer, "get_inference_info"))

    def test_get_forward_info_uses_effective_input_length_for_prefill(self):
        """Prefill with input_length=200 and a 0.5 prefix-cache hit rate issues a 100/100 request."""
        self.backend.model_runner = Mock()
        self.backend.num_mtp_tokens = 0
        optimizer_data = OptimizerData(
            input_length=200,
            output_length=64,
            prefix_cache_hit_rate=0.5,
            batch_size=1,
        )
        self.backend._get_forward_info(4, optimizer_data, is_decode=False)
        first_request = self._captured_requests()[0]
        self.assertEqual(first_request.query_len, 100)
        self.assertEqual(first_request.seq_len, 100)

    def test_get_forward_info_keeps_original_input_length_for_decode(self):
        """Decode with the same inputs and no MTP tokens issues query_len=1, seq_len=233."""
        self.backend.model_runner = Mock()
        self.backend.num_mtp_tokens = 0
        optimizer_data = OptimizerData(
            input_length=200,
            output_length=64,
            prefix_cache_hit_rate=0.5,
            batch_size=1,
        )
        self.backend._get_forward_info(4, optimizer_data, is_decode=True)
        first_request = self._captured_requests()[0]
        self.assertEqual(first_request.query_len, 1)
        self.assertEqual(first_request.seq_len, 233)

    def test_resolve_forward_shape_uses_effective_prefill_by_default(self):
        """Without overrides, the prefill shape resolves to (100, 100) for a 200-token, 0.5-hit-rate input."""
        optimizer_data = OptimizerData(input_length=200, output_length=64, prefix_cache_hit_rate=0.5)
        query_len, seq_len = self.backend._resolve_forward_shape(optimizer_data, is_decode=False)
        self.assertEqual((query_len, seq_len), (100, 100))

    def test_resolve_forward_shape_accepts_prefill_chunk_overrides(self):
        """Explicit query_len/seq_len arguments are returned unchanged for prefill."""
        optimizer_data = OptimizerData(input_length=200, output_length=64, prefix_cache_hit_rate=0.5)
        query_len, seq_len = self.backend._resolve_forward_shape(
            optimizer_data,
            is_decode=False,
            query_len=32,
            seq_len=128,
        )
        self.assertEqual((query_len, seq_len), (32, 128))

    def test_resolve_forward_shape_uses_original_prompt_for_decode(self):
        """With num_mtp_tokens=2 and no overrides, the decode shape resolves to (3, 235)."""
        self.backend.num_mtp_tokens = 2
        optimizer_data = OptimizerData(input_length=200, output_length=64, prefix_cache_hit_rate=0.5)
        query_len, seq_len = self.backend._resolve_forward_shape(optimizer_data, is_decode=True)
        self.assertEqual((query_len, seq_len), (3, 235))

    def test_resolve_forward_shape_accepts_decode_overrides(self):
        """Explicit query_len/seq_len arguments are returned unchanged for decode."""
        self.backend.num_mtp_tokens = 2
        optimizer_data = OptimizerData(input_length=200, output_length=64, prefix_cache_hit_rate=0.5)
        query_len, seq_len = self.backend._resolve_forward_shape(
            optimizer_data,
            is_decode=True,
            query_len=4,
            seq_len=256,
        )
        self.assertEqual((query_len, seq_len), (4, 256))

    def test_get_forward_info_uses_explicit_image_batch_size_when_provided(self):
        """An explicit image_batch_size=1 is carried onto the request even when batch_size=8."""
        self.backend.model_runner = Mock()
        optimizer_data = OptimizerData(
            input_length=32,
            output_length=64,
            batch_size=8,
            image_batch_size=1,
            image_height=1080,
            image_width=1920,
        )
        self.backend._get_forward_info(8, optimizer_data, is_decode=False)
        self.assertEqual(self._captured_requests()[0].image_batch_size, 1)

    def test_get_forward_info_falls_back_to_batch_size_for_image_batch_size(self):
        """With image_batch_size=None the request's image_batch_size equals the batch size (8)."""
        self.backend.model_runner = Mock()
        optimizer_data = OptimizerData(
            input_length=32,
            output_length=64,
            batch_size=8,
            image_batch_size=None,
            image_height=1080,
            image_width=1920,
        )
        self.backend._get_forward_info(8, optimizer_data, is_decode=False)
        self.assertEqual(self._captured_requests()[0].image_batch_size, 8)

    def test_exponential_search_acc_search_clamps_right_boundary_on_early_stop(self):
        """When the right probe early-stops, the search returns (1, 8) using the patched estimate of 8."""
        optimizer_data = OptimizerData(
            concurrency_search_strategy="linear_exponential", tpot_limits=50, ttft_limits=500
        )
        summary_left = Mock(spec=OptimizerSummary)
        summary_left.get_search_info.return_value = {"tpot": 10.0, "ttft": 100.0}
        summary_right = Mock(spec=OptimizerSummary)
        summary_right.check_early_stop_flag.return_value = True
        summary_right.get_search_info.return_value = {
            "per_request_memory_gb": 2.0,
            "device_memory_available_gb": -10.0,
            "tpot": 60.0,
            "ttft": 600.0,
        }
        with (
            patch.object(self.backend, "get_inference_info", return_value=summary_right),
            patch.object(self.backend, "_estimate_right_boundary", return_value=8) as mock_estimate,
        ):
            left, right = self.backend._exponential_search(optimizer_data, 1, 512, summary_left, True)
        self.assertEqual((left, right), (1, 8))
        mock_estimate.assert_called_once()

    def test_exponential_search_acc_search_uses_estimated_boundary_before_stop(self):
        """When the right probe does not early-stop, the patched estimate (530) becomes the right bound."""
        optimizer_data = OptimizerData(
            concurrency_search_strategy="linear_exponential", tpot_limits=50, ttft_limits=500
        )
        summary_left = Mock(spec=OptimizerSummary)
        summary_left.get_search_info.return_value = {"tpot": 10.0, "ttft": 100.0}
        summary_right = Mock(spec=OptimizerSummary)
        summary_right.check_early_stop_flag.return_value = False
        summary_right.get_search_info.return_value = {
            "per_request_memory_gb": 0.5,
            "device_memory_available_gb": 10.0,
            "tpot": 40.0,
            "ttft": 400.0,
        }
        with (
            patch.object(self.backend, "get_inference_info", return_value=summary_right),
            patch.object(self.backend, "_estimate_right_boundary", return_value=530),
        ):
            left, right = self.backend._exponential_search(optimizer_data, 1, 512, summary_left, True)
        self.assertEqual((left, right), (1, 530))

    def test_compute_per_request_memory_gb_handles_zero_and_positive_batch_size(self):
        """A batch size of 0 yields 0; a batch size of 4 yields 5.5 for the same memory figures."""
        memory_figures = {
            "total_device_memory_gb": 64,
            "model_weight_size_gb": 20,
            "reserved_memory_gb": 10,
            "memory_left_gb": 12,
        }
        self.assertEqual(
            self.backend._compute_per_request_memory_gb(**memory_figures, batch_size=0),
            0,
        )
        self.assertEqual(
            self.backend._compute_per_request_memory_gb(**memory_figures, batch_size=4),
            5.5,
        )

    def test_estimate_right_boundary_falls_back_to_max_search_size(self):
        """With no latency limits and zeroed memory info the estimate is 2**19 - 1."""
        optimizer_data = OptimizerData(tpot_limits=None, ttft_limits=None)
        estimated = self.backend._estimate_right_boundary(
            {"batch_size": 1},
            {"batch_size": 8, "per_request_memory_gb": 0, "device_memory_available_gb": 0},
            optimizer_data,
        )
        self.assertEqual(estimated, 2**19 - 1)

    def test_estimate_right_boundary_uses_memory_limit_and_skips_fallback(self):
        """With 2.0 GB per request and 10.0 GB available at batch size 8 the estimate is 14."""
        optimizer_data = OptimizerData(tpot_limits=None, ttft_limits=None)
        estimated = self.backend._estimate_right_boundary(
            {"batch_size": 1},
            {
                "batch_size": 8,
                "per_request_memory_gb": 2.0,
                "device_memory_available_gb": 10.0,
            },
            optimizer_data,
        )
        self.assertEqual(estimated, 14)

    def test_exponential_search_without_acc_search_doubles_until_max_iterations(self):
        """Starting at (1, 2) with no early stop, the search ends at (1024, 2048) after 10 probes."""
        optimizer_data = OptimizerData(concurrency_search_strategy="exponential")
        summary_left = Mock(spec=OptimizerSummary)
        summary_right = Mock(spec=OptimizerSummary)
        summary_right.check_early_stop_flag.return_value = False
        with patch.object(self.backend, "get_inference_info", return_value=summary_right) as mock_get_inference_info:
            left, right = self.backend._exponential_search(optimizer_data, 1, 2, summary_left)
        self.assertEqual((left, right), (1024, 2048))
        self.assertEqual(mock_get_inference_info.call_count, 10)

    def test_estimate_by_latency_returns_relaxed_boundary_when_latency_grows(self):
        """For latency 10.0 at bs=2, 30.0 at bs=6, limit 20.0 and relax factor 1.5, the estimate is 7."""
        estimated = self.backend._estimate_by_latency(
            bs_left=2,
            bs_right=6,
            lat_left=10.0,
            lat_right=30.0,
            lat_limit=20.0,
            relax_factor=1.5,
            estimated_right=999,
        )
        self.assertEqual(estimated, 7)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()