import argparse
import unittest
from serving_cast.service.utils import (
BatchRangeAction,
OptimizerData,
PrefillChunk,
check_positive_float,
check_positive_integer,
check_string_valid,
)
class TestServiceUtils(unittest.TestCase):
    """Unit tests for the argument validators and OptimizerData helpers in serving_cast.service.utils"""

    def test_check_string_valid_within_limit_and_valid_chars(self):
        """Test check_string_valid with valid string"""
        valid_string = "valid_string123/test-path.file"
        result = check_string_valid(valid_string, max_len=100)
        self.assertEqual(result, valid_string)

    def test_check_positive_integer_valid(self):
        """Test check_positive_integer with valid integers (string and int input)"""
        for raw, expected in (("1", 1), ("100", 100), (5, 5)):
            with self.subTest(value=raw):
                self.assertEqual(check_positive_integer(raw), expected)

    def test_check_positive_integer_invalid_string(self):
        """Test check_positive_integer with invalid string"""
        with self.assertRaises(argparse.ArgumentTypeError):
            check_positive_integer("abc")

    def test_check_positive_integer_non_positive(self):
        """Test check_positive_integer with non-positive values"""
        for raw in ("0", "-1"):
            with self.subTest(value=raw):
                with self.assertRaises(argparse.ArgumentTypeError):
                    check_positive_integer(raw)

    def test_check_positive_integer_too_large(self):
        """Test check_positive_integer with very large value"""
        with self.assertRaises(argparse.ArgumentTypeError):
            check_positive_integer("2000000")

    def test_check_positive_float_valid(self):
        """Test check_positive_float with valid floats, including case-insensitive infinity"""
        cases = (
            ("1.5", 1.5),
            ("100", 100.0),
            ("inf", float("inf")),
            ("INF", float("inf")),
        )
        for raw, expected in cases:
            with self.subTest(value=raw):
                self.assertEqual(check_positive_float(raw), expected)

    def test_check_positive_float_invalid(self):
        """Test check_positive_float with non-numeric and non-positive values"""
        for raw in ("abc", "0", "-1.5"):
            with self.subTest(value=raw):
                with self.assertRaises(argparse.ArgumentTypeError):
                    check_positive_float(raw)

    def test_optimizer_data_creation(self):
        """Test OptimizerData creation with default values"""
        config = OptimizerData()
        self.assertIsNone(config.input_length)
        self.assertIsNone(config.output_length)
        self.assertEqual(config.prefix_cache_hit_rate, 0.0)

    def test_optimizer_data_effective_input_length_with_prefix_cache(self):
        """Test that the prefix-cache hit rate reduces the effective input length"""
        config = OptimizerData(input_length=200, prefix_cache_hit_rate=0.5)
        self.assertEqual(config.get_effective_input_length(), 100)

    def test_optimizer_data_effective_input_length_ignores_prefix_cache_in_decode(self):
        """Test that is_decode=True leaves the input length unreduced by the prefix cache"""
        config = OptimizerData(input_length=200, prefix_cache_hit_rate=0.5)
        self.assertEqual(config.get_effective_input_length(is_decode=True), 200)

    def test_optimizer_data_prefill_chunk_plan_single_chunk(self):
        """Test that an input within the token budget yields a single chunk"""
        config = OptimizerData(input_length=4096, max_batched_tokens=8192)
        self.assertEqual(
            config.get_prefill_chunk_plan(),
            [PrefillChunk(index=0, query_len=4096, seq_len=4096)],
        )

    def test_optimizer_data_prefill_chunk_plan_multiple_chunks(self):
        """Test splitting into full-budget chunks plus a remainder chunk"""
        config = OptimizerData(input_length=10000, max_batched_tokens=4096)
        self.assertEqual(
            config.get_prefill_chunk_plan(),
            [
                PrefillChunk(index=0, query_len=4096, seq_len=4096),
                PrefillChunk(index=1, query_len=4096, seq_len=8192),
                PrefillChunk(index=2, query_len=1808, seq_len=10000),
            ],
        )

    def test_optimizer_data_prefill_chunk_plan_applies_prefix_cache(self):
        """Test that chunking operates on the prefix-cache-reduced input length"""
        config = OptimizerData(input_length=10, max_batched_tokens=3, prefix_cache_hit_rate=0.5)
        self.assertEqual(
            config.get_prefill_chunk_plan(),
            [
                PrefillChunk(index=0, query_len=3, seq_len=3),
                PrefillChunk(index=1, query_len=2, seq_len=5),
            ],
        )

    def test_optimizer_data_prefill_chunk_plan_returns_empty_without_input_length(self):
        """Test that no input length yields an empty plan, even with no token budget"""
        config = OptimizerData(max_batched_tokens=None)
        self.assertEqual(config.get_prefill_chunk_plan(), [])

    def test_optimizer_data_prefill_chunk_plan_rejects_invalid_token_budget(self):
        """Test that a missing or non-positive token budget raises ValueError"""
        for max_batched_tokens in (None, 0, -1):
            with self.subTest(max_batched_tokens=max_batched_tokens):
                config = OptimizerData(input_length=10, max_batched_tokens=max_batched_tokens)
                with self.assertRaises(ValueError):
                    config.get_prefill_chunk_plan()

    def test_optimizer_data_prefill_num_chunks_matches_chunk_plan(self):
        """Test that get_prefill_num_chunks agrees with the length of the chunk plan"""
        config = OptimizerData(input_length=9, max_batched_tokens=4)
        plan = config.get_prefill_chunk_plan()
        self.assertEqual(config.get_prefill_num_chunks(), 3)
        # The test name promises consistency with the plan; check it directly
        # instead of relying only on the hard-coded count.
        self.assertEqual(config.get_prefill_num_chunks(), len(plan))
class TestBatchRangeAction(unittest.TestCase):
    """Test BatchRangeAction class functionality"""

    def setUp(self):
        """Build a fresh parser, namespace and action before each test method."""
        # setUp runs before every test, so each test gets an isolated namespace.
        self.parser = argparse.ArgumentParser()
        self.namespace = argparse.Namespace()
        self.action = BatchRangeAction(option_strings=["--batch-range"], dest="batch_range")

    def test_valid_single_value(self):
        """Test BatchRangeAction with valid single value"""
        self.action(self.parser, self.namespace, [100])
        self.assertEqual(self.namespace.batch_range, [100])

    def test_valid_range_values(self):
        """Test BatchRangeAction with valid range values"""
        self.action(self.parser, self.namespace, [10, 100])
        self.assertEqual(self.namespace.batch_range, [10, 100])

    def test_invalid_range_order(self):
        """Test BatchRangeAction with invalid range order"""
        with self.assertRaises(argparse.ArgumentTypeError):
            self.action(self.parser, self.namespace, [100, 10])