import os
import threading
import time
from pathlib import Path
from unittest.mock import patch, MagicMock
import numpy as np
import pytest
import torch
from ms_serviceparam_optimizer.config.config import get_settings
from ms_serviceparam_optimizer.inference.constant import IS_SLEEP_FLAG
from ms_serviceparam_optimizer.inference.data_format_v1 import BatchField, RequestField
from ms_serviceparam_optimizer.inference.simulate import Simulate, predict_queue, ServiceField, FileLogger
class TestFileLogger:
    """Tests for ``FileLogger.open_file`` and ``FileLogger.write``."""

    @staticmethod
    def _close_fout(logger):
        """Close ``logger.fout`` if it is a still-open handle; a no-op when it is ``None``."""
        fout = getattr(logger, "fout", None)
        if fout is not None and not fout.closed:
            fout.close()

    @staticmethod
    def test_open_file_with_path(logger, file_path):
        """Opening a ``pathlib.Path`` target in ``'w'`` mode leaves an open handle on ``fout``."""
        logger.file_path = file_path
        logger.mode = 'w'
        logger.open_file()
        try:
            assert logger.fout is not None
            assert logger.fout.closed is False
        finally:
            # Close even when an assert fails so the handle is not leaked.
            TestFileLogger._close_fout(logger)

    @staticmethod
    def test_open_file_with_string(logger, file_path):
        """Same as the ``Path`` case, but the target is passed as a plain ``str``."""
        logger.file_path = str(file_path)
        logger.mode = 'w'
        logger.open_file()
        try:
            assert logger.fout is not None
            assert logger.fout.closed is False
        finally:
            TestFileLogger._close_fout(logger)

    @staticmethod
    def test_open_file_with_invalid_path(logger):
        """A ``None`` path makes ``open_file`` raise ``TypeError``."""
        logger.file_path = None
        logger.mode = 'w'
        with pytest.raises(TypeError):
            logger.open_file()

    @staticmethod
    def test_open_file_with_invalid_mode(logger, file_path):
        """Calling ``open_file`` with mode ``'x'`` must not raise out of the call."""
        logger.file_path = file_path
        logger.mode = 'x'
        try:
            logger.open_file()
        finally:
            # Release any handle that was opened so the scratch file is not leaked.
            TestFileLogger._close_fout(logger)

    @staticmethod
    def test_write_with_none_fout(logger, file_path):
        """``write`` with no open handle (``fout is None``) completes without raising."""
        logger.fout = None
        logger.lock = threading.Lock()
        message = "test message"
        logger.write(message)

    @staticmethod
    def test_write_with_not_none_fout(logger, file_path):
        """``write`` writes the message and a newline to ``fout`` and flushes exactly once."""
        mock_file = MagicMock()
        logger.fout = mock_file
        logger.lock = threading.Lock()
        message = "test message"
        logger.write(message)
        mock_file.write.assert_any_call(message)
        mock_file.write.assert_any_call("\n")
        mock_file.flush.assert_called_once()

    @pytest.fixture
    def logger(self):
        """``FileLogger`` targeting the per-process simulator CSV under the configured output dir."""
        # NOTE(review): if FileLogger.__init__ opens this path eagerly, the tests above
        # replace that handle; confirm whether it also needs closing.
        return FileLogger(Path(get_settings().simulator_output).joinpath(f"simulate_{os.getpid()}.csv"))

    @pytest.fixture
    def file_path(self, tmp_path):
        """Per-test scratch log path inside pytest's ``tmp_path``, so nothing is written to the CWD."""
        return tmp_path / "test.log"
class TestSimulate:
    """Tests for ``Simulate.generate_random_token``."""

    # Small shape / upper bound shared by the basic property checks.
    SHAPE = (2, 3)
    MAX_VALUE = 32000
    # A 2x3 draw almost never contains eos_token_id, so the replacement path would
    # go untested. With this many elements, eos is expected to be drawn several
    # times before replacement.
    # NOTE(review): assumes roughly uniform sampling below MAX_VALUE -- confirm.
    EOS_PROBE_SHAPE = (256, 512)

    @staticmethod
    def test_generate_random_token_shape(plugin_object):
        """The returned array has exactly the requested shape."""
        shape = TestSimulate.SHAPE
        result = Simulate.generate_random_token(plugin_object, shape, TestSimulate.MAX_VALUE)
        assert result.shape == shape, "Generated array shape does not match the expected shape."

    @staticmethod
    def test_generate_random_token_eos_token_replacement(plugin_object):
        """No element of a large draw equals ``eos_token_id``."""
        result = Simulate.generate_random_token(
            plugin_object, TestSimulate.EOS_PROBE_SHAPE, TestSimulate.MAX_VALUE
        )
        assert plugin_object.eos_token_id not in result, "eos_token_id should be replaced in the generated array."

    @staticmethod
    def test_generate_random_token_value_range(plugin_object):
        """All generated values lie in ``[0, max_value]``."""
        max_value = TestSimulate.MAX_VALUE
        result = Simulate.generate_random_token(plugin_object, TestSimulate.SHAPE, max_value)
        assert result.min() >= 0 and result.max() <= max_value, "Generated array values are out of the expected range."

    @staticmethod
    def test_generate_random_token_no_replacement_needed(plugin_object):
        """The element count equals the product of the requested shape."""
        shape = TestSimulate.SHAPE
        result = Simulate.generate_random_token(plugin_object, shape, TestSimulate.MAX_VALUE)
        assert result.size == np.prod(shape), "The size of the generated array does not match the product of the shape."

    @pytest.fixture
    def plugin_object(self):
        """Minimal stand-in exposing only ``eos_token_id`` (10000)."""
        class PluginObject:
            def __init__(self):
                self.eos_token_id = 10000
        return PluginObject()
def test_generate_logits():
    """``Simulate.generate_logits`` honours the requested shape and dtype on CPU."""
    target = "cpu"

    # Every (batch, vocab) combination, batch-major, matching the original loop order.
    shape_cases = [(bs, vs) for bs in (1, 2, 10) for vs in (128, 1024, 129280)]
    for batch, vocab in shape_cases:
        out = Simulate.generate_logits(batch, vocab, device=target)
        assert out.shape == (batch, vocab), (
            f"Expected shape ({batch}, {vocab}), got {out.shape}"
        )

    # dtype name passed to generate_logits -> torch dtype expected back.
    expected_dtypes = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    for name, want in expected_dtypes.items():
        out = Simulate.generate_logits(1, device=target, dtype=name)
        assert out.dtype == want, f"Expected dtype torch.{name}, got {out.dtype}"
class TestSimulateUpdateToken:
    """Tests for ``Simulate.update_token`` forcing EOS on sampled and top tokens."""

    @staticmethod
    def test_update_token_with_eos_token(setup):
        """An already-EOS sampled token stays EOS; empty top tokens stay empty."""
        plugin_object, input_metadata, cached_ids, sampling_output = setup
        ServiceField.req_id_and_max_decode_length = {0: 10}
        Simulate.update_token(plugin_object, input_metadata, cached_ids, sampling_output)
        assert sampling_output.token_ids[0].item() == plugin_object.eos_token_id
        assert sampling_output.top_token_ids.size == 0

    @staticmethod
    def test_update_token_with_top_token(setup):
        """An EOS top token is replaced by a non-EOS value."""
        plugin_object, input_metadata, cached_ids, sampling_output = setup
        ServiceField.req_id_and_max_decode_length = {0: 10}
        sampling_output.top_token_ids = np.array([[50256]])
        Simulate.update_token(plugin_object, input_metadata, cached_ids, sampling_output)
        assert sampling_output.token_ids[0].item() == plugin_object.eos_token_id
        assert sampling_output.top_token_ids[0].item() != plugin_object.eos_token_id

    @staticmethod
    def test_update_token_with_max_length_reached(setup):
        """Once the request's max decode length (1) is reached, the token is forced to EOS."""
        plugin_object, input_metadata, cached_ids, sampling_output = setup
        ServiceField.req_id_and_max_decode_length = {0: 1}
        sampling_output.token_ids = np.array([50224])
        Simulate.update_token(plugin_object, input_metadata, cached_ids, sampling_output)
        assert sampling_output.token_ids[0] == plugin_object.eos_token_id
        assert sampling_output.top_token_ids.size == 0

    @staticmethod
    def test_update_token_with_max_length_reached_with_top_token(setup):
        """At max decode length, both the sampled token and the top token become EOS."""
        plugin_object, input_metadata, cached_ids, sampling_output = setup
        ServiceField.req_id_and_max_decode_length = {0: 1}
        sampling_output.token_ids = np.array([50224])
        sampling_output.top_token_ids = np.array([[50224]])
        Simulate.update_token(plugin_object, input_metadata, cached_ids, sampling_output)
        assert sampling_output.token_ids[0] == plugin_object.eos_token_id
        assert sampling_output.top_token_ids[0].item() == plugin_object.eos_token_id

    @staticmethod
    def test_update_token_with_no_request_id(setup):
        """With no max-length entry for the request, the sampled token ends up EOS."""
        plugin_object, input_metadata, cached_ids, sampling_output = setup
        ServiceField.req_id_and_max_decode_length = {}
        Simulate.update_token(plugin_object, input_metadata, cached_ids, sampling_output)
        assert sampling_output.token_ids[0] == plugin_object.eos_token_id
        assert sampling_output.top_token_ids.size == 0

    @pytest.fixture
    def setup(self, monkeypatch):
        """Build ``(plugin_object, input_metadata, cached_ids, sampling_output)`` for request id 0.

        The tests assign ``ServiceField.req_id_and_max_decode_length`` directly. Registering
        the attribute with ``monkeypatch`` first makes teardown restore the original
        value (or delete it if it was absent), so the global does not leak between tests.
        """
        monkeypatch.setattr(ServiceField, "req_id_and_max_decode_length", {}, raising=False)
        plugin_object = MagicMock()
        plugin_object.input_manager.cache.output_len_count = np.full((3, 1), 1, dtype=np.int32)
        plugin_object.eos_token_id = 50256
        plugin_object.model_wrapper.config.vocab_size = 50256
        input_metadata = MagicMock()
        input_metadata.batch_request_ids = np.array([0])
        cached_ids = 0
        sampling_output = MagicMock()
        sampling_output.token_ids = np.array([50256])
        # Shape (1, 0): one row with no top tokens.
        sampling_output.top_token_ids = np.zeros((1, 0), dtype=np.int32)
        return plugin_object, input_metadata, cached_ids, sampling_output
class MockFileHandler:
    """Empty placeholder installed as ``ServiceField.fh``; it carries no state or behavior."""
class MockDataProcessor:
    """Empty placeholder installed as ``ServiceField.data_processor``; it carries no state or behavior."""
class MockConfigPath:
    """Stand-in for ``ServiceField.config_path`` that exposes only an empty ``cache_data`` dict."""

    def __init__(self):
        # A fresh dict per instance, so separate mocks never share cached entries.
        self.cache_data: dict = dict()
@pytest.fixture(name="predict_setup")
def fixture_predict_setup(monkeypatch):
    """Install an empty prediction cache and a 20-request decode batch on ``ServiceField``.

    Every attribute is set through ``monkeypatch``, so teardown restores the original
    ``Simulate`` / ``ServiceField`` class state instead of leaking it into later tests.
    ``raising=False`` allows attributes that do not exist yet.
    """
    monkeypatch.setattr(Simulate, "predict_cache", {}, raising=False)
    monkeypatch.setattr(
        ServiceField, "batch_field", BatchField("decode", 20, 20.0, 580.0, 29.0), raising=False
    )
    monkeypatch.setattr(
        ServiceField, "request_field", tuple(RequestField(29.0, 1, 2) for _ in range(20)), raising=False
    )
    monkeypatch.setattr(ServiceField, "config_path", MockConfigPath(), raising=False)
    monkeypatch.setattr(ServiceField, "fh", MockFileHandler(), raising=False)
    monkeypatch.setattr(ServiceField, "data_processor", MockDataProcessor(), raising=False)
def test_predict_with_sleep(predict_setup, monkeypatch):
    """With the sleep flag on, ``predict`` waits at least 0.3 s, returns the prediction, and caches it."""
    monkeypatch.setattr(
        "ms_serviceparam_optimizer.inference.simulate.predict_v1_with_cache", lambda *args, **kwargs: (-1, 300000)
    )
    # Use setenv rather than os.environ[...] = ... so the flag is undone after this test.
    monkeypatch.setenv(IS_SLEEP_FLAG, "true")
    assert len(Simulate.predict_cache) == 0
    st = time.perf_counter()
    result = Simulate.predict()
    # NOTE(review): the 0.3 s bound presumably corresponds to 300000 microseconds -- confirm the unit.
    assert time.perf_counter() - st > 0.3
    assert len(Simulate.predict_cache) == 1
    assert result == 300000
    # A repeated call returns the same value without adding a cache entry.
    result = Simulate.predict()
    assert result == 300000
    assert len(Simulate.predict_cache) == 1
def test_predict_without_sleep(predict_setup, monkeypatch):
    """With the sleep flag off, ``predict`` returns in under 0.3 s and still caches the result."""
    monkeypatch.setattr(
        "ms_serviceparam_optimizer.inference.simulate.predict_v1_with_cache", lambda *args, **kwargs: (-1, 300000)
    )
    # Use setenv rather than os.environ[...] = ... so the flag is undone after this test.
    monkeypatch.setenv(IS_SLEEP_FLAG, "false")
    assert len(Simulate.predict_cache) == 0
    st = time.perf_counter()
    result = Simulate.predict()
    assert time.perf_counter() - st < 0.3
    assert result == 300000
    assert len(Simulate.predict_cache) == 1
    # A repeated call returns the same value without adding a cache entry.
    result = Simulate.predict()
    assert result == 300000
    assert len(Simulate.predict_cache) == 1
@patch('ms_serviceparam_optimizer.inference.simulate.Simulate.predict')
def test_predict_and_save(mock_predict):
    """Each ``predict_and_save`` call invokes ``predict`` once and enqueues exactly one item."""

    def drain():
        """Empty the shared module-level ``predict_queue``."""
        while not predict_queue.empty():
            predict_queue.get()

    mock_predict.return_value = MagicMock()
    # predict_queue is shared module state: start from empty so leftovers from
    # other tests cannot break the qsize() checks.
    drain()
    try:
        Simulate.predict_and_save()
        mock_predict.assert_called_once()
        assert predict_queue.qsize() == 1
        drain()
        Simulate.predict_and_save(time_sleep=True)
        assert mock_predict.call_count == 2
        assert predict_queue.qsize() == 1
    finally:
        # Leave the queue empty even if an assertion above fails.
        drain()