import multiprocessing as mp
import os
import shutil
import tempfile
import threading
import time
import unittest
from dataclasses import dataclass

import numpy as np

from serving_cast.config import Config, ParallelConfig
from serving_cast.model_runner import (
    AsyncTask,
    AsyncTaskManager,
    CompletionEventManager,
    InterpolationPoint,
    ModelRunner,
    ModelRunnerMetricCacheManager,
)
from serving_cast.request import Request, RequestState
from tensor_cast.core.input_generator import RequestInfo
from tensor_cast.core.model_runner import ModelRunner as TensorCastModelRunner
from tensor_cast.core.model_runner import ModelRunnerMetrics
from tensor_cast.core.quantization.datatypes import QuantizeAttentionAction, QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig


@dataclass
class MockParsedArgs:
    """Lightweight stand-in for the parsed arguments object handed to ``Config``."""

    # Path of the YAML file that holds the ``instance_groups`` description.
    instance_config_path: str
    # Path of the YAML file that holds the model / load-gen / serving settings.
    common_config_path: str
    # Profiling stays disabled unless a test explicitly opts in.
    enable_profiling: bool = False
def create_test_config_files():
    """Write a minimal common/instance YAML config pair into a fresh temp directory.

    The common config selects ``Qwen/Qwen3-32B`` with multi-process execution
    and interpolation disabled; the instance config declares a single
    one-device instance group.

    Returns:
        A ``(tmp_dir, common_path, instance_path)`` tuple of strings. The
        caller owns ``tmp_dir`` and is responsible for deleting it.

    Raises:
        Exception: Whatever importing PyYAML or writing the files raises. If
            the failure happens after the directory was created, the
            directory is removed before the error propagates.
    """
    # Function-scope import: PyYAML is only needed when files are written.
    import yaml

    common_config = {
        "model_config": {
            "name": "Qwen/Qwen3-32B",
            "enable_multi_process": False,
            "enable_interpolate": False,
        },
        "load_gen": {
            "load_gen_type": "poisson",
            "num_requests": 10,
            "num_input_tokens": 100,
            "num_output_tokens": 50,
            "request_rate": 1.0,
        },
        "serving_config": {
            "max_concurrency": 100,
            "block_size": 128,
            "max_tokens_budget": 8192,
        },
    }
    instance_config = {
        "instance_groups": [
            {
                "num_instances": 1,
                "num_devices_per_instance": 1,
                "pd_role": "both",
                "parallel_config": {
                    "world_size": 1,
                    "tp_size": 1,
                },
            }
        ]
    }

    tmp_dir = tempfile.mkdtemp()
    try:
        common_path = os.path.join(tmp_dir, "common.yaml")
        instance_path = os.path.join(tmp_dir, "instances.yaml")
        for path, payload in ((common_path, common_config), (instance_path, instance_config)):
            with open(path, "w", encoding="utf-8") as f:
                yaml.dump(payload, f)
    except BaseException:
        # mkdtemp() never cleans up after itself; do not leak the directory
        # when the caller never receives its path.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return tmp_dir, common_path, instance_path
class TestTensorCastModelRunner(unittest.TestCase):
    """Tests for the tensor_cast ``ModelRunner`` (imported as ``TensorCastModelRunner``)."""

    @staticmethod
    def _make_bare_runner(total_device_memory_gb):
        """Return a runner built without ``__init__`` for memory-check tests.

        Only the attributes these tests assign are populated: a 40 GB model
        weight size, the given total device memory, and a user config that
        reserves 10 GB.
        """
        runner = TensorCastModelRunner.__new__(TensorCastModelRunner)
        runner.model_weight_size_gb = 40.0
        runner.total_device_memory_gb = total_device_memory_gb
        runner.user_input = UserInputConfig(
            device="TEST_DEVICE",
            model_id="Qwen/Qwen3-32B",
            reserved_memory_gb=10.0,
        )
        return runner

    def test_init_valid_device(self):
        """A known device yields a model with the requested parallel config."""
        runner = TensorCastModelRunner(
            UserInputConfig(
                device="TEST_DEVICE",
                model_id="Qwen/Qwen3-32B",
                world_size=1,
                tp_size=1,
            )
        )
        self.assertIsNotNone(runner.model)
        self.assertEqual(runner.model.model_config.parallel_config.world_size, 1)
        self.assertEqual(runner.model.model_config.parallel_config.tensor_parallel_size, 1)
        self.assertIsNotNone(runner.model.model_config.quant_config)

    def test_init_invalid_device(self):
        """An unknown device name is rejected with ``ValueError``."""
        with self.assertRaises(ValueError):
            TensorCastModelRunner(
                UserInputConfig(
                    device="invalid-device",
                    model_id="test-model",
                    world_size=1,
                    tp_size=1,
                )
            )

    def test_run_inference_basic(self):
        """A mixed prefill/decode batch produces metrics."""
        mock_requests: list[RequestInfo] = [
            RequestInfo(query_len=10, seq_len=10, is_decode=False),
            RequestInfo(query_len=1, seq_len=10, is_decode=True),
        ]
        runner = TensorCastModelRunner(
            UserInputConfig(
                device="TEST_DEVICE",
                model_id="Qwen/Qwen3-32B",
            )
        )
        metrics = runner.run_inference(mock_requests)
        self.assertIsNotNone(metrics)

    def test_run_inference_with_ep(self):
        """Inference runs with quantization and 8-way TP/EP configured."""
        model_runner = TensorCastModelRunner(
            UserInputConfig(
                device="TEST_DEVICE",
                model_id="deepseek-ai/DeepSeek-V3.1",
                quantize_linear_action=QuantizeLinearAction.FP8,
                quantize_attention_action=QuantizeAttentionAction.INT8,
                world_size=8,
                tp_size=8,
                dp_size=1,
                ep_size=8,
            )
        )
        requests = [RequestInfo(1, 65, True)]
        metrics = model_runner.run_inference(requests)
        self.assertIsNotNone(metrics)

    def test_check_peak_memory_usage_gb_scales_heterogeneous_prefill_batch(self):
        """A prefill batch with differing lengths gets its peak usage rescaled."""
        runner = self._make_bare_runner(total_device_memory_gb=64.0)
        peak_memory_usage_gb = runner._check_peak_memory_usage_gb(
            peak_memory_usage_gb=80.0,
            kv_cache_size_gb=5.0,
            requests=[
                RequestInfo(query_len=250, seq_len=250, is_decode=False),
                RequestInfo(query_len=1000, seq_len=1000, is_decode=False),
            ],
        )
        # The value is computed with float arithmetic, so compare approximately
        # rather than bit-for-bit.
        self.assertAlmostEqual(peak_memory_usage_gb, 48.5)

    def test_check_peak_memory_usage_gb_keeps_homogeneous_prefill_batch(self):
        """A prefill batch of identical requests keeps the reported peak usage."""
        runner = self._make_bare_runner(total_device_memory_gb=64.0)
        peak_memory_usage_gb = runner._check_peak_memory_usage_gb(
            peak_memory_usage_gb=80.0,
            kv_cache_size_gb=5.0,
            requests=[
                RequestInfo(query_len=250, seq_len=250, is_decode=False),
                RequestInfo(query_len=250, seq_len=250, is_decode=False),
            ],
        )
        self.assertAlmostEqual(peak_memory_usage_gb, 80.0)

    def test_check_peak_memory_usage_gb_clamps_negative_activation_for_heterogeneous_prefill_batch(self):
        """A peak below weights + KV cache does not drive the result under their sum."""
        runner = self._make_bare_runner(total_device_memory_gb=50.0)
        peak_memory_usage_gb = runner._check_peak_memory_usage_gb(
            peak_memory_usage_gb=42.0,
            kv_cache_size_gb=5.0,
            requests=[
                RequestInfo(query_len=250, seq_len=250, is_decode=False),
                RequestInfo(query_len=1000, seq_len=1000, is_decode=False),
            ],
        )
        # 45.0 == 40.0 (weights) + 5.0 (KV cache); compared approximately
        # because it is the result of float arithmetic.
        self.assertAlmostEqual(peak_memory_usage_gb, 45.0)
class TestInterpolationPoint(unittest.TestCase):
    """Checks on the ``InterpolationPoint`` value object."""

    def test_interpolation_point_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        created = InterpolationPoint(total_seq_len=100, total_query_len=50)

        self.assertEqual(created.total_seq_len, 100)
        self.assertEqual(created.total_query_len, 50)

    def test_interpolation_point_equality(self):
        """Points compare equal exactly when their fields match."""
        base = InterpolationPoint(total_seq_len=100, total_query_len=50)
        twin = InterpolationPoint(total_seq_len=100, total_query_len=50)
        longer = InterpolationPoint(total_seq_len=200, total_query_len=50)

        self.assertEqual(base, twin)
        self.assertNotEqual(base, longer)
class TestAsyncTask(unittest.TestCase):
    """Checks on ``AsyncTask`` construction and hashing."""

    def test_async_task_creation(self):
        """The task keeps its batch and carries a hash value."""
        requests = [
            RequestInfo(query_len=10, seq_len=10, is_decode=False),
            RequestInfo(query_len=1, seq_len=20, is_decode=True),
        ]
        created = AsyncTask(requests)

        self.assertEqual(created.batch, requests)
        self.assertIsNotNone(created.hash_value)

    def test_async_task_hash_consistency(self):
        """Two tasks wrapping the same batch share one hash value."""
        requests = [
            RequestInfo(query_len=10, seq_len=10, is_decode=False),
        ]
        first = AsyncTask(requests)
        second = AsyncTask(requests)

        self.assertEqual(first.hash_value, second.hash_value)

    def test_async_task_hash_different(self):
        """Tasks wrapping different batches get different hash values."""
        short_batch = [RequestInfo(query_len=10, seq_len=10, is_decode=False)]
        long_batch = [RequestInfo(query_len=20, seq_len=20, is_decode=False)]
        first = AsyncTask(short_batch)
        second = AsyncTask(long_batch)

        self.assertNotEqual(first.hash_value, second.hash_value)

    def test_async_task_get_hash(self):
        """``get_hash()`` returns the stored ``hash_value``."""
        requests = [RequestInfo(query_len=10, seq_len=100, is_decode=False)]
        created = AsyncTask(requests)

        reported = created.get_hash()

        self.assertEqual(reported, created.hash_value)
class TestModelRunnerStaticMethods(unittest.TestCase):
    """Exercises the static helpers exposed by ``ModelRunner``."""

    # ------------------------------------------------------------------ #
    # Fixture builders
    # ------------------------------------------------------------------ #
    @staticmethod
    def _tracked_info(query_len, seq_len, is_decode):
        """Build a ``RequestInfo`` with 100 input tokens and 50 output tokens."""
        return RequestInfo(
            query_len=query_len,
            seq_len=seq_len,
            num_input_tokens=100,
            num_output_tokens=50,
            is_decode=is_decode,
        )

    @staticmethod
    def _request_in_state(state, query_len, seq_len, num_input_tokens=100, num_output_tokens=50):
        """Build a ``Request`` and assign its state and lengths."""
        req = Request(num_input_tokens=num_input_tokens, num_output_tokens=num_output_tokens)
        req.state = state
        req.query_len = query_len
        req.seq_len = seq_len
        return req

    @staticmethod
    def _triangle_model():
        """Fit an interpolation model on the three points (0,0), (1,0), (0,1)."""
        corners = np.array([[0, 0], [1, 0], [0, 1]])
        values = np.array([1.0, 2.0, 3.0])
        return ModelRunner.get_interpolation_model(corners, values)

    # ------------------------------------------------------------------ #
    # get_interpolation_point
    # ------------------------------------------------------------------ #
    def test_get_interpolation_point(self):
        """Sequence and query lengths are summed over the batch."""
        requests = [
            RequestInfo(query_len=10, seq_len=100, is_decode=False),
            RequestInfo(query_len=1, seq_len=50, is_decode=True),
        ]
        summary = ModelRunner.get_interpolation_point(requests)
        self.assertEqual(summary.total_seq_len, 150)
        self.assertEqual(summary.total_query_len, 11)

    def test_get_interpolation_point_empty(self):
        """An empty batch maps to the all-zero point."""
        requests = []
        summary = ModelRunner.get_interpolation_point(requests)
        self.assertEqual(summary.total_seq_len, 0)
        self.assertEqual(summary.total_query_len, 0)

    def test_get_interpolation_point_single(self):
        """A single request maps to its own lengths."""
        requests = [RequestInfo(query_len=100, seq_len=500, is_decode=False)]
        summary = ModelRunner.get_interpolation_point(requests)
        self.assertEqual(summary.total_seq_len, 500)
        self.assertEqual(summary.total_query_len, 100)

    # ------------------------------------------------------------------ #
    # predict_next_batch
    # ------------------------------------------------------------------ #
    def test_predict_next_batch_prefill(self):
        """A prefill request is predicted to continue prefilling."""
        present = [self._tracked_info(query_len=10, seq_len=5, is_decode=False)]
        upcoming = ModelRunner.predict_next_batch(present)
        self.assertEqual(len(upcoming), 1)
        predicted = upcoming[0]
        self.assertEqual(predicted.query_len, 90)
        self.assertEqual(predicted.seq_len, 5)
        self.assertFalse(predicted.is_decode)

    def test_predict_next_batch_decode(self):
        """A decode request advances by one token."""
        present = [self._tracked_info(query_len=1, seq_len=100, is_decode=True)]
        upcoming = ModelRunner.predict_next_batch(present)
        self.assertEqual(len(upcoming), 1)
        predicted = upcoming[0]
        self.assertEqual(predicted.query_len, 1)
        self.assertEqual(predicted.seq_len, 101)
        self.assertTrue(predicted.is_decode)

    def test_predict_next_batch_finished(self):
        """A request on its last decode step is absent from the next batch."""
        present = [self._tracked_info(query_len=1, seq_len=149, is_decode=True)]
        upcoming = ModelRunner.predict_next_batch(present)
        self.assertEqual(len(upcoming), 0)

    def test_predict_next_batch_invalid_seq_len(self):
        """A seq_len beyond input + output tokens raises ``ValueError``."""
        present = [self._tracked_info(query_len=1, seq_len=200, is_decode=True)]
        with self.assertRaises(ValueError):
            ModelRunner.predict_next_batch(present)

    def test_predict_next_batch_multiple_requests(self):
        """Both a prefill and an unfinished decode request are carried forward."""
        present = [
            self._tracked_info(query_len=10, seq_len=5, is_decode=False),
            self._tracked_info(query_len=1, seq_len=120, is_decode=True),
        ]
        upcoming = ModelRunner.predict_next_batch(present)
        self.assertEqual(len(upcoming), 2)

    # ------------------------------------------------------------------ #
    # request2info
    # ------------------------------------------------------------------ #
    def test_request2info_prefill(self):
        """A PREFILLING request converts to a non-decode info."""
        req = self._request_in_state(RequestState.PREFILLING, query_len=10, seq_len=10)
        infos = ModelRunner.request2info([req])
        self.assertEqual(len(infos), 1)
        self.assertEqual(infos[0].query_len, 10)
        self.assertEqual(infos[0].seq_len, 10)
        self.assertFalse(infos[0].is_decode)

    def test_request2info_decode(self):
        """A DECODING request converts to a decode info."""
        req = self._request_in_state(RequestState.DECODING, query_len=1, seq_len=150)
        infos = ModelRunner.request2info([req])
        self.assertEqual(len(infos), 1)
        self.assertEqual(infos[0].query_len, 1)
        self.assertEqual(infos[0].seq_len, 150)
        self.assertTrue(infos[0].is_decode)

    def test_request2info_recomputation(self):
        """A RECOMPUTATION request converts to a non-decode info."""
        req = self._request_in_state(RequestState.RECOMPUTATION, query_len=10, seq_len=10)
        infos = ModelRunner.request2info([req])
        self.assertEqual(len(infos), 1)
        self.assertFalse(infos[0].is_decode)

    def test_request2info_invalid_state(self):
        """A request still in INITIAL state is rejected with ``ValueError``."""
        req = self._request_in_state(RequestState.INITIAL, query_len=10, seq_len=10)
        with self.assertRaises(ValueError):
            ModelRunner.request2info([req])

    def test_request2info_query_gt_seq(self):
        """query_len larger than seq_len is rejected with ``ValueError``."""
        req = self._request_in_state(RequestState.PREFILLING, query_len=20, seq_len=10)
        with self.assertRaises(ValueError):
            ModelRunner.request2info([req])

    def test_request2info_multiple_requests(self):
        """Every request in the list yields one info."""
        prefill_req = self._request_in_state(RequestState.PREFILLING, query_len=10, seq_len=10)
        decode_req = self._request_in_state(
            RequestState.DECODING,
            query_len=1,
            seq_len=250,
            num_input_tokens=200,
            num_output_tokens=100,
        )
        infos = ModelRunner.request2info([prefill_req, decode_req])
        self.assertEqual(len(infos), 2)

    # ------------------------------------------------------------------ #
    # get_interpolation_model
    # ------------------------------------------------------------------ #
    def test_get_interpolation_model_basic(self):
        """A fitted model returns something for a single 2-D point."""
        predict = self._triangle_model()
        outcome = predict([0.33, 0.33])
        self.assertIsNotNone(outcome)

    def test_get_interpolation_model_invalid_x_shape(self):
        """One-dimensional x is rejected with ``ValueError``."""
        flat_x = np.array([1, 2, 3])
        values = np.array([1.0, 2.0, 3.0])
        with self.assertRaises(ValueError):
            ModelRunner.get_interpolation_model(flat_x, values)

    def test_get_interpolation_model_invalid_y_shape(self):
        """Two-dimensional y is rejected with ``ValueError``."""
        points = np.array([[1, 1], [2, 2], [3, 3]])
        column_y = np.array([[1.0], [2.0], [3.0]])
        with self.assertRaises(ValueError):
            ModelRunner.get_interpolation_model(points, column_y)

    def test_get_interpolation_model_mismatched_lengths(self):
        """x and y of different lengths are rejected with ``ValueError``."""
        points = np.array([[1, 1], [2, 2], [3, 3]])
        short_y = np.array([1.0, 2.0])
        with self.assertRaises(ValueError):
            ModelRunner.get_interpolation_model(points, short_y)

    def test_get_interpolation_model_multiple_points(self):
        """Predicting for two points returns two results."""
        square = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
        values = np.array([1.0, 2.0, 3.0, 4.0])
        predict = ModelRunner.get_interpolation_model(square, values)
        outcome = predict([[0.5, 0.5], [0.5, 0.5]])
        self.assertEqual(len(outcome), 2)

    def test_get_interpolation_model_single_point_invalid(self):
        """A single query point with three coordinates raises ``ValueError``."""
        predict = self._triangle_model()
        with self.assertRaises(ValueError):
            predict([1, 2, 3])

    def test_get_interpolation_model_multiple_points_invalid_shape(self):
        """Several query points with three coordinates each raise ``ValueError``."""
        predict = self._triangle_model()
        with self.assertRaises(ValueError):
            predict([[1, 2, 3], [4, 5, 6]])
class TestModelRunnerMetricCacheManager(unittest.TestCase):
    """Covers slot creation, lookup and recording in ``ModelRunnerMetricCacheManager``."""

    def setUp(self):
        """Start a real ``multiprocessing`` manager for each test."""
        self.manager = mp.Manager()

    def tearDown(self):
        """Stop the ``multiprocessing`` manager started in ``setUp``."""
        self.manager.shutdown()

    def _cache_with_slot(self, cache_id):
        """Return a cache manager that already has ``cache_id`` initialised."""
        store = ModelRunnerMetricCacheManager(self.manager)
        store.init_cache_slot(cache_id)
        return store

    def _create_test_metrics(self):
        """Return a fully populated ``ModelRunnerMetrics`` sample."""
        return ModelRunnerMetrics(
            total_device_memory_gb=80.0,
            model_weight_size_gb=15.0,
            peak_memory_usage_gb=50.0,
            kv_cache_size_gb=5.0,
            kv_cache_per_token_gb=0.001,
            model_activation_size_gb=10.0,
            reserved_memory_gb=0.0,
            device_memory_available_gb=10.0,
            execution_time_s={"analytic": 0.5},
            tps_per_model={"analytic": 100.0},
            run_time_s=1.0,
            batch_size=1,
        )

    def test_init_cache_slot(self):
        """Initialising a slot makes its id appear in the cache."""
        store = self._cache_with_slot("test_id")
        self.assertIn("test_id", store.cache)

    def test_init_cache_slot_duplicate(self):
        """Initialising the same slot twice raises ``ValueError``."""
        store = self._cache_with_slot("test_id")
        with self.assertRaises(ValueError):
            store.init_cache_slot("test_id")

    def test_get_cache(self):
        """``get_cache`` returns whatever is stored under the id."""
        store = self._cache_with_slot("test_id")
        store.cache["test_id"] = "test_value"
        fetched = store.get_cache("test_id")
        self.assertEqual(fetched, "test_value")

    def test_get_cache_not_found(self):
        """Looking up an id that was never initialised raises ``KeyError``."""
        store = ModelRunnerMetricCacheManager(self.manager)
        with self.assertRaises(KeyError):
            store.get_cache("non_existent")

    def test_record_cache(self):
        """``record_cache`` stores the metrics under the given id."""
        store = self._cache_with_slot("test_id")
        sample = self._create_test_metrics()
        store.record_cache("test_id", sample)
        self.assertEqual(store.cache["test_id"], sample)

    def test_record_cache_not_found(self):
        """Recording under an id that was never initialised raises ``KeyError``."""
        store = ModelRunnerMetricCacheManager(self.manager)
        sample = self._create_test_metrics()
        with self.assertRaises(KeyError):
            store.record_cache("non_existent", sample)

    def test_cache_round_trip(self):
        """Metrics read back from the cache carry the values that were recorded."""
        store = self._cache_with_slot("metrics_id")
        recorded = ModelRunnerMetrics(
            total_device_memory_gb=80.0,
            model_weight_size_gb=15.0,
            peak_memory_usage_gb=50.0,
            kv_cache_size_gb=20.0,
            kv_cache_per_token_gb=0.0005,
            model_activation_size_gb=10.0,
            reserved_memory_gb=0.0,
            device_memory_available_gb=80.0,
            execution_time_s={"analytic": 1.23, "empirical": 1.45},
            tps_per_model={"analytic": 100.0},
            run_time_s=2.0,
            batch_size=2,
        )
        store.record_cache("metrics_id", recorded)
        fetched = store.get_cache("metrics_id")
        self.assertEqual(fetched.execution_time_s["analytic"], 1.23)
        self.assertEqual(fetched.device_memory_available_gb, 80.0)
class TestCompletionEventManager(unittest.TestCase):
    """Tests for CompletionEventManager class.

    Resources are released through ``addCleanup`` so that a failing assertion
    can no longer skip ``shutdown()`` and leave the background thread running.
    """

    # Upper bound on how long a test waits for an event to become set.
    _EVENT_TIMEOUT_S = 5.0
    # Pause between two ``is_set()`` polls.
    _POLL_INTERVAL_S = 0.01

    def setUp(self):
        """Start a real multiprocessing Manager and schedule its shutdown."""
        self.manager = mp.Manager()
        # Registered first, therefore executed last: cleanups run in LIFO
        # order, so event managers are shut down while the Manager is alive.
        self.addCleanup(self.manager.shutdown)

    def _new_event_manager(self):
        """Create a CompletionEventManager that is shut down automatically."""
        event_manager = CompletionEventManager(self.manager)
        self.addCleanup(event_manager.shutdown)
        return event_manager

    def _wait_until_set(self, event_manager, event_id):
        """Poll until the event is set or the timeout expires.

        Returns:
            The final ``is_set()`` value of the event.
        """
        deadline = time.monotonic() + self._EVENT_TIMEOUT_S
        while time.monotonic() < deadline:
            if event_manager.event_dict[event_id].is_set():
                return True
            time.sleep(self._POLL_INTERVAL_S)
        return event_manager.event_dict[event_id].is_set()

    def test_init_event_slot(self):
        """Test init_event_slot method."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("test_event")
        self.assertIn("test_event", event_manager.event_dict)

    def test_init_event_slot_duplicate(self):
        """Test init_event_slot with duplicate event_id raises error."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("test_event")
        with self.assertRaises(ValueError):
            event_manager.init_event_slot("test_event")

    def test_set_completion_event(self):
        """Test set_completion_event method."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("test_event")
        event_manager.set_completion_event("test_event")
        # Poll instead of sleeping a fixed 0.5 s: faster when the event is
        # set promptly, more tolerant when the machine is slow.
        self.assertTrue(self._wait_until_set(event_manager, "test_event"))

    def test_wait_completion_event(self):
        """Test wait_completion_event method."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("test_event")

        def set_event():
            time.sleep(0.1)
            event_manager.set_completion_event("test_event")

        setter_thread = threading.Thread(target=set_event)
        setter_thread.start()
        # Registered after the event manager's cleanup, so the thread is
        # joined before the event manager is shut down.
        self.addCleanup(setter_thread.join)
        event_manager.wait_completion_event("test_event")
        self.assertTrue(event_manager.event_dict["test_event"].is_set())

    def test_shutdown(self):
        """Test shutdown method."""
        # shutdown() is the behaviour under test, so it is called explicitly
        # and no automatic cleanup is registered.
        event_manager = CompletionEventManager(self.manager)
        event_manager.init_event_slot("test_event")
        event_manager.shutdown()
        self.assertFalse(event_manager._thread_running)

    def test_shutdown_with_empty_queue(self):
        """Test shutdown with empty queue."""
        event_manager = CompletionEventManager(self.manager)
        event_manager.shutdown()
        self.assertFalse(event_manager._thread_running)

    def test_shutdown_clears_event_dict(self):
        """Test that shutdown clears the event dictionary."""
        event_manager = CompletionEventManager(self.manager)
        event_manager.init_event_slot("event1")
        event_manager.init_event_slot("event2")
        event_manager.shutdown()
        self.assertEqual(len(event_manager.event_dict), 0)

    def test_multiple_events(self):
        """Test handling multiple events."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("event1")
        event_manager.init_event_slot("event2")
        event_manager.set_completion_event("event1")
        event_manager.set_completion_event("event2")
        self.assertTrue(self._wait_until_set(event_manager, "event1"))
        self.assertTrue(self._wait_until_set(event_manager, "event2"))
class TestCompletionEventManagerThread(unittest.TestCase):
    """Tests for CompletionEventManager background thread behavior."""

    def setUp(self):
        """Start a multiprocessing Manager and schedule its shutdown."""
        self.manager = mp.Manager()
        # Registered first, therefore executed last (cleanups run LIFO), so
        # an event manager is always shut down before the Manager.
        self.addCleanup(self.manager.shutdown)

    def test_thread_running_after_init(self):
        """Test that background thread is running after initialization."""
        event_manager = CompletionEventManager(self.manager)
        # Guarantee shutdown even when one of the assertions below fails;
        # previously a failure left the background thread running.
        self.addCleanup(event_manager.shutdown)
        self.assertTrue(event_manager._thread_running)
        self.assertTrue(event_manager._event_thread.is_alive())

    def test_thread_stops_after_shutdown(self):
        """Test that background thread stops after shutdown."""
        # shutdown() is the behaviour under test, so it is called explicitly.
        event_manager = CompletionEventManager(self.manager)
        event_manager.shutdown()
        event_manager._event_thread.join(timeout=2)
        self.assertFalse(event_manager._event_thread.is_alive())
class TestProcessCompletionQueue(unittest.TestCase):
    """Tests for _process_completion_queue edge cases."""

    # Upper bound on how long a test waits for an event to become set.
    _EVENT_TIMEOUT_S = 5.0

    def setUp(self):
        """Start a multiprocessing Manager and schedule its shutdown."""
        self.manager = mp.Manager()
        # Registered first, therefore executed last (cleanups run LIFO), so
        # event managers are shut down while the Manager is still alive.
        self.addCleanup(self.manager.shutdown)

    def _new_event_manager(self):
        """Create a CompletionEventManager that is shut down automatically."""
        event_manager = CompletionEventManager(self.manager)
        self.addCleanup(event_manager.shutdown)
        return event_manager

    def test_process_queue_with_none_event_id(self):
        """Test that None event_id is skipped in queue processing."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("real_event")
        event_manager.completion_queue.put(None)
        # Asserting that nothing happens needs a fixed grace period; there is
        # no condition to poll for.
        time.sleep(0.5)
        self.assertFalse(event_manager.event_dict["real_event"].is_set())

    def test_process_queue_sets_event(self):
        """Test that event is set when processing queue."""
        event_manager = self._new_event_manager()
        event_manager.init_event_slot("test_event")
        event_manager.completion_queue.put("test_event")
        # Poll with a deadline instead of sleeping a fixed 0.5 s.
        deadline = time.monotonic() + self._EVENT_TIMEOUT_S
        while time.monotonic() < deadline and not event_manager.event_dict["test_event"].is_set():
            time.sleep(0.01)
        self.assertTrue(event_manager.event_dict["test_event"].is_set())

    def test_process_queue_unknown_event_raises(self):
        """Enqueue an unknown event_id and check that shutdown still completes.

        Any error for the unknown id would be raised inside the background
        thread, so this test cannot observe it and makes no assertion about
        it. It only verifies that the queue can be fed an unregistered id and
        the manager can be shut down afterwards without raising.
        """
        event_manager = self._new_event_manager()
        event_manager.completion_queue.put("unknown_event")
        # Give the background thread time to pick the item up before the
        # registered cleanup shuts the event manager down.
        time.sleep(1)
class TestAsyncTaskManager(unittest.TestCase):
    """Tests for AsyncTaskManager class."""

    def test_add_task(self):
        """Check the bookkeeping done for a task: cache slot, event slot, queue.

        No ``AsyncTaskManager`` is instantiated here. The test drives the
        cache manager, the event manager and a manager queue directly, using
        the task hash as the shared id.
        """
        batch = [RequestInfo(query_len=10, seq_len=100, is_decode=False)]
        task = AsyncTask(batch)

        manager = mp.Manager()
        # Cleanups run LIFO: the event manager (registered below) is shut
        # down first, the Manager last. Both now run even when an assertion
        # fails, which previously leaked the Manager and the event thread.
        self.addCleanup(manager.shutdown)
        task_queue = manager.Queue()
        cache_manager = ModelRunnerMetricCacheManager(manager)
        event_manager = CompletionEventManager(manager)
        self.addCleanup(event_manager.shutdown)

        task_hash = task.hash_value
        cache_manager.init_cache_slot(task_hash)
        event_manager.init_event_slot(task_hash)
        task_queue.put(task)

        self.assertIn(task_hash, cache_manager.cache)
        self.assertIn(task_hash, event_manager.event_dict)
class TestModelRunnerIntegration(unittest.TestCase):
    """Integration tests for ModelRunner with Config."""

    @classmethod
    def setUpClass(cls):
        """Create config files and build a fresh ``Config`` from them.

        ``tearDownClass`` is not run when ``setUpClass`` raises, so the temp
        directory is removed here if building the ``Config`` fails.
        """
        cls.tmp_dir, cls.common_path, cls.instance_path = create_test_config_files()
        try:
            # Reset the Config singleton state so this class gets its own config.
            Config._instance = None
            Config._initialized = False
            cls.config = Config(
                MockParsedArgs(
                    instance_config_path=cls.instance_path,
                    common_config_path=cls.common_path,
                )
            )
        except BaseException:
            shutil.rmtree(cls.tmp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Reset the Config singleton state and remove the temp files."""
        Config._instance = None
        Config._initialized = False
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    @staticmethod
    def _single_device_parallel_config():
        """Return the world_size=1 / tp_size=1 / moe_dp_size=1 parallel config."""
        return ParallelConfig(
            world_size=1,
            tp_size=1,
            moe_dp_size=1,
        )

    def _new_runner(self):
        """Create a ``ModelRunner`` whose ``shutdown()`` always runs.

        The shutdown is registered as a cleanup, so it also happens when an
        assertion in the test fails.
        """
        runner = ModelRunner(self._single_device_parallel_config(), "TEST_DEVICE", dp_rank=0)
        self.addCleanup(runner.shutdown)
        return runner

    def test_init_tensor_cast_model_runner(self):
        """Test init_tensor_cast_model_runner static method."""
        runner = ModelRunner.init_tensor_cast_model_runner(
            self.config.common_config,
            self._single_device_parallel_config(),
            "TEST_DEVICE",
        )
        self.assertIsNotNone(runner)

    def test_model_runner_init(self):
        """Test ModelRunner initialization."""
        runner = self._new_runner()
        self.assertIsNotNone(runner.tensor_cast_model_runner)
        self.assertFalse(runner.enable_multi_process)

    def test_model_runner_get_kv_cache_num_bytes(self):
        """Test get_kv_cache_num_bytes method."""
        runner = self._new_runner()
        num_bytes = runner.get_kv_cache_num_bytes(100)
        self.assertIsNotNone(num_bytes)

    def test_model_runner_get_inputs_num_bytes(self):
        """Test get_inputs_num_bytes method."""
        runner = self._new_runner()
        request = Request(num_input_tokens=100, num_output_tokens=50)
        request.state = RequestState.PREFILLING
        request.query_len = 10
        request.seq_len = 10
        num_bytes = runner.get_inputs_num_bytes([request])
        self.assertIsInstance(num_bytes, int)

    def test_model_runner_process_batch(self):
        """Test that process_batch runs on a single prefill request without raising."""
        runner = self._new_runner()
        request = Request(num_input_tokens=100, num_output_tokens=50)
        request.state = RequestState.PREFILLING
        request.query_len = 10
        request.seq_len = 10
        runner.process_batch([request])

    def test_model_runner_warmup(self):
        """Test warmup method."""
        runner = self._new_runner()
        num_blocks, block_size = runner.warmup()
        self.assertIsInstance(num_blocks, int)
        self.assertIsInstance(block_size, int)

    def test_apply_interpolation_model_not_ready(self):
        """Test apply_interpolation_model raises when not ready."""
        runner = self._new_runner()
        runner._interpolation_ready = False
        runner._interpolation_model = None
        batch = [RequestInfo(query_len=10, seq_len=100, is_decode=False)]
        with self.assertRaises(ValueError):
            runner.apply_interpolation_model(batch)

    def test_apply_interpolation_model_with_model(self):
        """Test apply_interpolation_model with a real interpolation model."""
        runner = self._new_runner()
        x = np.array([[0, 0], [1000, 0], [0, 1000], [1000, 1000]])
        y = np.array([0.1, 0.5, 0.3, 0.8])
        runner._interpolation_model = ModelRunner.get_interpolation_model(x, y)
        runner._interpolation_ready = True
        batch = [RequestInfo(query_len=100, seq_len=500, is_decode=False)]
        result = runner.apply_interpolation_model(batch)
        self.assertIsInstance(result, float)
class TestModelRunnerWithInterpolation(unittest.TestCase):
    """Tests for ModelRunner with interpolation enabled."""

    @classmethod
    def setUpClass(cls):
        """Write config files with ``enable_interpolate=True`` and build ``Config``.

        ``tearDownClass`` is not run when ``setUpClass`` raises, so the temp
        directory is removed here if writing the files or building the
        ``Config`` fails.
        """
        # Function-scope import: PyYAML is only needed when files are written.
        import yaml

        common_config = {
            "model_config": {
                "name": "Qwen/Qwen3-32B",
                "enable_multi_process": False,
                "enable_interpolate": True,
                "interpolation_seed": 42,
            },
            "load_gen": {
                "load_gen_type": "poisson",
                "num_requests": 10,
                "num_input_tokens": 100,
                "num_output_tokens": 50,
                "request_rate": 1.0,
            },
            "serving_config": {
                "max_concurrency": 10,
                "block_size": 128,
                "max_tokens_budget": 512,
            },
        }
        instance_config = {
            "instance_groups": [
                {
                    "num_instances": 1,
                    "num_devices_per_instance": 1,
                    "pd_role": "both",
                    "parallel_config": {
                        "world_size": 1,
                        "tp_size": 1,
                    },
                }
            ]
        }

        tmp_dir = tempfile.mkdtemp()
        try:
            common_path = os.path.join(tmp_dir, "common.yaml")
            instance_path = os.path.join(tmp_dir, "instances.yaml")
            for path, payload in ((common_path, common_config), (instance_path, instance_config)):
                with open(path, "w", encoding="utf-8") as f:
                    yaml.dump(payload, f)

            cls.tmp_dir = tmp_dir
            # Reset the Config singleton state so this class gets its own config.
            Config._instance = None
            Config._initialized = False
            cls.config = Config(
                MockParsedArgs(
                    instance_config_path=instance_path,
                    common_config_path=common_path,
                )
            )
        except BaseException:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Reset the Config singleton state and remove the temp files."""
        Config._instance = None
        Config._initialized = False
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    def _new_runner(self):
        """Create a ``ModelRunner`` whose ``shutdown()`` always runs.

        The shutdown is registered as a cleanup, so it also happens when an
        assertion in the test fails.
        """
        parallel_config = ParallelConfig(
            world_size=1,
            tp_size=1,
            moe_dp_size=1,
        )
        runner = ModelRunner(parallel_config, "TEST_DEVICE", dp_rank=0)
        self.addCleanup(runner.shutdown)
        return runner

    def test_model_runner_init_with_interpolate(self):
        """Test ModelRunner initialization with interpolation enabled."""
        runner = self._new_runner()
        self.assertTrue(runner.enable_interpolate)
        self.assertTrue(runner._interpolation_ready)
        self.assertIsNotNone(runner._interpolation_model)

    def test_model_runner_process_batch_with_interpolate(self):
        """Test that process_batch runs without raising when interpolation is enabled."""
        runner = self._new_runner()
        request = Request(num_input_tokens=100, num_output_tokens=50)
        request.state = RequestState.PREFILLING
        request.query_len = 10
        request.seq_len = 10
        runner.process_batch([request])
class TestAsyncTaskManagerFull(unittest.TestCase):
    """Tests for AsyncTaskManager actual initialization with Config."""

    @classmethod
    def setUpClass(cls):
        """Create config files and build a fresh ``Config`` from them.

        ``tearDownClass`` is not run when ``setUpClass`` raises, so the temp
        directory is removed here if building the ``Config`` fails.
        """
        cls.tmp_dir, cls.common_path, cls.instance_path = create_test_config_files()
        try:
            # Reset the Config singleton state so this class gets its own config.
            Config._instance = None
            Config._initialized = False
            cls.config = Config(
                MockParsedArgs(
                    instance_config_path=cls.instance_path,
                    common_config_path=cls.common_path,
                )
            )
        except BaseException:
            shutil.rmtree(cls.tmp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Reset the Config singleton state and remove the temp files."""
        Config._instance = None
        Config._initialized = False
        shutil.rmtree(cls.tmp_dir, ignore_errors=True)

    def _new_task_manager(self, *, register_cleanup=True):
        """Create a two-worker ``AsyncTaskManager`` on ``TEST_DEVICE``.

        Args:
            register_cleanup: When true, ``shutdown()`` is registered as a
                test cleanup so the workers are stopped even if an assertion
                fails. Pass ``False`` when the test calls ``shutdown()``
                itself.
        """
        task_manager = AsyncTaskManager(
            device_type="TEST_DEVICE",
            parallel_config=ParallelConfig(
                world_size=1,
                tp_size=1,
                moe_dp_size=1,
            ),
            num_workers=2,
        )
        if register_cleanup:
            self.addCleanup(task_manager.shutdown)
        return task_manager

    def test_async_task_manager_init_and_shutdown(self):
        """Test AsyncTaskManager initialization and shutdown."""
        task_manager = self._new_task_manager(register_cleanup=False)
        try:
            self.assertIsNotNone(task_manager.workers)
            self.assertEqual(len(task_manager.workers), 2)
        finally:
            # Exactly one shutdown, and it happens even when the assertions
            # above fail, so the workers are never left running.
            task_manager.shutdown()
        self.assertEqual(len(task_manager.workers), 0)

    def test_async_task_manager_add_task(self):
        """Test AsyncTaskManager add_task method."""
        task_manager = self._new_task_manager()
        batch = [RequestInfo(query_len=10, seq_len=100, is_decode=False)]
        task_manager.add_task(batch)
        task_hash = AsyncTask(batch).hash_value
        self.assertIn(task_hash, task_manager.task_record)
        # Submitting the same batch a second time must not raise.
        task_manager.add_task(batch)

    def test_async_task_manager_find_result_not_in_record(self):
        """Test AsyncTaskManager find_result with task not in record."""
        task_manager = self._new_task_manager()
        batch = [RequestInfo(query_len=10, seq_len=100, is_decode=False)]
        result = task_manager.find_result(batch)
        self.assertIsNone(result)
# Run the whole suite through unittest's runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()