import json
import shutil
import re
from pathlib import Path
import tempfile
import pytest
from unittest.mock import patch, MagicMock, mock_open
import csv
import pandas as pd
from msserviceprofiler.modelevalstate.config.config import PerformanceIndex, get_settings, AisBenchConfig, \
OptimizerConfigField
from msserviceprofiler.modelevalstate.optimizer.plugins.benchmark import parse_result, AisBench, VllmBenchMark
from msserviceprofiler.msguard import GlobalConfig
settings = get_settings()
@pytest.fixture(scope="function")
def aisbench_test_environment():
"""创建AisBench测试环境的fixture"""
test_dir = Path(tempfile.mkdtemp())
result_dir = test_dir / "ais_bench"
result_dir.mkdir(exist_ok=True)
aisbench_dir = result_dir / "outputs"
aisbench_dir.mkdir(exist_ok=True)
performance_dir = aisbench_dir / "performances"
performance_dir.mkdir(exist_ok=True)
output_dir = performance_dir / "api_file"
output_dir.mkdir(exist_ok=True)
csv_path = output_dir / "gsm8kdataset.csv"
json_path = output_dir / "gsm8kdataset.json"
json_data = {
"Total Requests": {"total": 84},
"Success Requests": {"total": 84},
"Request Throughput": {"total": "2.4221 req/s"},
"Output Token Throughput": {"total": "1240.1267 token/s"},
"Concurrency": {"total": 40},
"Max Concurrency": {"total": 100}
}
with open(json_path, 'w') as f:
json.dump(json_data, f)
data = [
{
"Performance Parameters": "TTFT",
"Stage": "total",
"Average": "146.1383 ms",
},
{
"Performance Parameters": "TPOT",
"Stage": "total",
"Average": "30.2947 ms",
}
]
with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
fieldnames = data[0].keys()
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for row in data:
writer.writerow(row)
yield {
"test_dir": test_dir,
"csv_path": csv_path,
"json_path": json_path
}
shutil.rmtree(test_dir, ignore_errors=True)
@pytest.fixture(scope="function")
def before_run_test_environment():
"""创建BeforeRun测试环境的fixture"""
test_dir = Path(tempfile.mkdtemp())
benchmark_dir = test_dir / "benchmark"
benchmark_dir.mkdir(exist_ok=True)
config_dir = benchmark_dir / "configs"
config_dir.mkdir(exist_ok=True)
model_dir = config_dir / "models"
model_dir.mkdir(exist_ok=True)
code_content = '''from ais_bench.benchmark.models import VLLMCustomAPIChatStream
models = [
dict(
attr="service",
type=VLLMCustomAPIChatStream,
abbr='vllm-api-stream-chat',
path="/data/models/llama3-8b",
model="llama3-8b",
request_rate=36,
retry=2,
host_ip="127.0.0.1",
host_port=31015,
max_out_len=512,
batch_size=1000,
trust_remote_code=False,
generation_kwargs=dict(
temperature=0.5,
top_k=10,
top_p=0.95,
seed=None,
repetition_penalty=1.03,
)
)
]'''
target_file = model_dir / "api.py"
with open(target_file, 'w', encoding='utf-8') as f:
f.write(code_content)
yield {
"test_dir": test_dir,
"model_dir": model_dir,
"api_file_path": target_file
}
shutil.rmtree(test_dir, ignore_errors=True)
class TestAisbench:
"""使用pytest fixtures的AisBench测试类"""
@patch("msserviceprofiler.modelevalstate.config.custom_command.shutil.which")
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.AisBench.get_models_config_path')
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.glob.glob')
def test_get_performance_metric(self, mock_glob, mock_path, mock_which, aisbench_test_environment):
"""测试获取性能指标"""
mock_which.return_value = "/usr/local/bin/aisbench"
mock_path.return_value = aisbench_test_environment["csv_path"]
mock_glob.return_value = [str(aisbench_test_environment["csv_path"])]
config = AisBenchConfig()
config.output_path = aisbench_test_environment["test_dir"]
result = AisBench(config).get_performance_metric('ttft')
assert result == 0.1461383
@patch("msserviceprofiler.modelevalstate.config.custom_command.shutil.which")
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.AisBench.get_models_config_path')
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.glob.glob')
def test_get_performance_index(self, mock_glob, mock_path, mock_which, aisbench_test_environment):
"""测试获取性能索引"""
mock_which.return_value = "/usr/local/bin/aisbench"
mock_path.return_value = aisbench_test_environment["csv_path"]
mock_glob.side_effect = [
[str(aisbench_test_environment["csv_path"])],
[str(aisbench_test_environment["csv_path"])],
[str(aisbench_test_environment["json_path"])]
]
config = AisBenchConfig()
config.output_path = aisbench_test_environment["test_dir"]
performance_index = AisBench(config).get_performance_index()
assert performance_index.generate_speed == 1240.1267
@patch("msserviceprofiler.modelevalstate.config.custom_command.shutil.which")
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.AisBench.get_models_config_path')
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.glob.glob')
def test_get_best_concurrency(self, mock_glob, mock_path, mock_which, aisbench_test_environment):
"""测试获取最佳并发数"""
mock_which.return_value = "/usr/local/bin/aisbench"
mock_path.return_value = aisbench_test_environment["json_path"]
mock_glob.return_value = [str(aisbench_test_environment["json_path"])]
config = AisBenchConfig()
config.output_path = aisbench_test_environment["test_dir"]
config.best_concurrency_coefficient = 5.0
config.best_concurrency_threshold = 200
result = AisBench(config).get_best_concurrency()
assert result == 200
class TestBeforeRun:
"""使用pytest fixtures的BeforeRun测试类"""
@patch("msserviceprofiler.modelevalstate.config.custom_command.shutil.which")
@patch('msserviceprofiler.modelevalstate.optimizer.plugins.benchmark.AisBench.get_models_config_path')
def test_before_run_file_exists(self, mock_path, mock_which, before_run_test_environment):
"""测试文件存在且成功修改request_rate和batch_size的情况"""
mock_which.return_value = "/usr/local/bin/aisbench"
mock_module = MagicMock()
mock_module.__file__ = 'ais_bench/__init__.py'
mock_path.return_value = before_run_test_environment["api_file_path"]
support_field = [
OptimizerConfigField(name="CONCURRENCY",
config_position="env",
min=25, max=300, dtype="int", value=100),
OptimizerConfigField(name="REQUESTRATE",
config_position="env",
min=1, max=25, dtype="int", value=100)
]
GlobalConfig.custom_return = True
AisBench().before_run(support_field)
pattern = re.compile(r"request_rate\s*=\s*(\d+)")
with open(before_run_test_environment["api_file_path"], 'r', encoding='utf-8') as f:
content = f.read()
match = pattern.search(content)
assert int(match.group(1)) == 100
GlobalConfig.reset()