import argparse
import json
import os
import shutil
import sqlite3
import tempfile
import unittest
from contextlib import closing
from pathlib import Path
from unittest.mock import MagicMock, patch

import pandas as pd

from ms_serviceparam_optimizer.train.source_to_train import (
    DatabaseConnector,
    source_to_model,
    req_decodetimes,
    read_batch_exec_data,
)
class TestSourceToTrainMindie(unittest.TestCase):
    """Tests for the MindIE data-preprocessing and training flow."""

    def setUp(self):
        """Build a private fixture directory holding profiler.db and request.csv."""
        # A unique temporary directory (instead of a fixed, CWD-relative name)
        # keeps every test hermetic: a directory left behind by a crashed run, or
        # one used concurrently by another test class, can no longer pre-populate
        # profiler.db -- re-inserting the batch_exec rows would violate its
        # PRIMARY KEY and break setUp.
        self.test_dir = Path(tempfile.mkdtemp(prefix="test_source_to_train_"))
        # Registered immediately so the directory is removed even if one of the
        # fixture builders below raises (unittest skips tearDown when setUp fails).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.db_path = self.test_dir / "profiler.db"
        self.create_sample_database()
        self.create_sample_csv()

    def create_sample_database(self):
        """Create the sample SQLite profiler database at ``self.db_path``.

        Tables written:
          * ``batch``: two ``BatchSchedule`` rows for request 101 (one
            ``Prefill``, one ``Decode``).
          * ``batch_exec``: two ``forward`` events for pid 1001.
          * ``batch_req``: request 101 at iterations 0 and 1 with ``block``
            values 256 and 192.
        """
        # The sqlite3 connection's own context manager only commits/rolls back;
        # closing() guarantees the handle is released even if a statement fails.
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch (
                    name TEXT,
                    res_list TEXT,
                    start_time REAL,
                    end_time REAL,
                    batch_size REAL,
                    batch_type TEXT,
                    during_time REAL
                )
            """)
            batch_data = [
                ("BatchSchedule", "[{'rid': 101, 'iter': 0}]", 1749451414153, 1749451414154, 1, "Prefill", 0.22175),
                ("BatchSchedule", "[{'rid': 101, 'iter': 0}]", 1749451414154, 1749451414155, 1, "Decode", 0.223),
            ]
            cursor.executemany(
                "INSERT INTO batch (name,res_list, start_time,end_time,batch_size,batch_type,"
                "during_time) VALUES (?, ?, ?, ?, ?, ?, ?)",
                batch_data,
            )

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_exec (
                    batch_id INTEGER PRIMARY KEY,
                    event TEXT,
                    pid INTEGER,
                    start REAL,
                    end REAL
                )
            """)
            exec_data = [(1, 'forward', 1001, 1000.0, 1500.0), (2, 'forward', 1001, 2000.0, 2500.0)]
            cursor.executemany(
                "INSERT INTO batch_exec (batch_id, event, pid, start, end) VALUES (?, ?, ?, ?, ?)", exec_data
            )

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_req (
                    batch_id INTEGER,
                    req_id TEXT,
                    rid TEXT,
                    iter INTEGER,
                    block INTEGER
                )
            """)
            req_data = [(1, "101", "101", "0", 256), (2, "101", "101", "1", 192)]
            cursor.executemany(
                "INSERT INTO batch_req (batch_id, req_id, rid, iter, block) VALUES (?, ?, ?, ?, ?)", req_data
            )
            conn.commit()

    def create_sample_csv(self):
        """Write ``request.csv`` describing a single request (http_rid 101)."""
        request_data = pd.DataFrame(
            {
                "http_rid": ["101"],
                "start_time": ["1749451414153"],
                "recv_token_size": ["256"],
                "reply_token_size": ["128"],
                "execution_time": ["1"],
                "queue_wait_time": ["0.11"],
                "first_token_latency": ["0.5"],
            }
        )
        request_data.to_csv(self.test_dir / "request.csv", index=False)

    def test_database_connector(self):
        """DatabaseConnector yields a cursor from which both batch_exec rows are read."""
        db_conn = DatabaseConnector(str(self.db_path))
        cursor = db_conn.connect()
        # Close the connection even if an assertion below fails; cleanups run
        # LIFO, so this happens before the fixture directory is deleted.
        self.addCleanup(db_conn.close)
        self.assertIsNotNone(cursor)
        exec_rows = read_batch_exec_data(cursor)
        self.assertEqual(len(exec_rows), 2)

    def test_source_to_model_mindie(self):
        """source_to_model (mindie) creates output_csv with a feature.csv per sub-directory."""
        source_to_model(self.test_dir, model_type='mindie')
        output_csv = self.test_dir / "output_csv"
        self.assertTrue(output_csv.exists())
        for pid_dir in output_csv.iterdir():
            if pid_dir.is_dir():
                self.assertTrue((pid_dir / "feature.csv").exists())

    def test_req_decodetimes(self):
        """req_decodetimes writes the request -> decode-count mapping as JSON."""
        output_dir = self.test_dir / "output"
        json_file = output_dir / "req_id_and_decode_num.json"
        output_dir.mkdir(parents=True, exist_ok=True)
        # Start from an empty file; the assertions below require that
        # req_decodetimes replaces its content with valid JSON.
        json_file.write_text("", encoding="utf-8")
        req_decodetimes(self.test_dir, output_dir)
        self.assertTrue(json_file.exists())
        data = json.loads(json_file.read_text(encoding="utf-8"))
        self.assertEqual(len(data), 1)
        self.assertEqual(data["0"], 128)
class TestSourceToTrainVllm(unittest.TestCase):
    """Tests for the vLLM data-preprocessing and training flow."""

    def setUp(self):
        """Build a private fixture directory with profiler.db, request.csv and kvcache.csv."""
        # A unique temporary directory (instead of a fixed, CWD-relative name
        # that other test classes may also use) keeps every test hermetic: a
        # stale or concurrently used directory can no longer pre-populate
        # profiler.db -- re-inserting the batch_exec rows would violate its
        # PRIMARY KEY and break setUp.
        self.test_dir = Path(tempfile.mkdtemp(prefix="test_source_to_train_"))
        # Registered immediately so the directory is removed even if one of the
        # fixture builders below raises (unittest skips tearDown when setUp fails).
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.db_path = self.test_dir / "profiler.db"
        self.create_sample_database()
        self.create_sample_csv()

    def create_sample_database(self):
        """Create the sample SQLite profiler database at ``self.db_path``.

        Tables written:
          * ``batch``: two ``batchFrameworkProcessing`` rows for request 101
            (one ``Prefill``, one ``Decode``) whose ``res_list`` uses the
            ``iter_size`` key.
          * ``batch_exec``: two ``forward`` events for pid 1001.
          * ``batch_req``: request 101 with ``iter_size`` 0 and 1 (no ``block``
            column in this variant).
        """
        # The sqlite3 connection's own context manager only commits/rolls back;
        # closing() guarantees the handle is released even if a statement fails.
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch (
                    name TEXT,
                    res_list TEXT,
                    start_time REAL,
                    end_time REAL,
                    batch_size REAL,
                    batch_type TEXT,
                    during_time REAL
                )
            """)
            batch_data = [
                (
                    "batchFrameworkProcessing",
                    "[{'rid': 101, 'iter_size': 0}]",
                    1749451414153,
                    1749451414154,
                    1,
                    "Prefill",
                    0.22175,
                ),
                (
                    "batchFrameworkProcessing",
                    "[{'rid': 101, 'iter_size': 0}]",
                    1749451414154,
                    1749451414155,
                    1,
                    "Decode",
                    0.223,
                ),
            ]
            cursor.executemany(
                "INSERT INTO batch (name,res_list, start_time,end_time,batch_size,batch_type,"
                "during_time) VALUES (?, ?, ?, ?, ?, ?, ?)",
                batch_data,
            )

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_exec (
                    batch_id INTEGER PRIMARY KEY,
                    event TEXT,
                    pid INTEGER,
                    start REAL,
                    end REAL
                )
            """)
            exec_data = [(1, 'forward', 1001, 1000.0, 1500.0), (2, 'forward', 1001, 2000.0, 2500.0)]
            cursor.executemany(
                "INSERT INTO batch_exec (batch_id, event, pid, start, end) VALUES (?, ?, ?, ?, ?)", exec_data
            )

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS batch_req (
                    batch_id INTEGER,
                    req_id TEXT,
                    rid TEXT,
                    iter_size INTEGER
                )
            """)
            req_data = [(1, "101", "101", "0"), (2, "101", "101", "1")]
            cursor.executemany(
                "INSERT INTO batch_req (batch_id, req_id, rid, iter_size) VALUES (?, ?, ?, ?)", req_data
            )
            conn.commit()

    def create_sample_csv(self):
        """Write ``request.csv`` (one request, http_rid 101) and ``kvcache.csv`` (two KVCache events)."""
        request_data = pd.DataFrame(
            {
                "http_rid": ["101"],
                "start_time": ["1749451414153"],
                "recv_token_size": ["256"],
                "reply_token_size": ["128"],
                "execution_time": ["1"],
                "queue_wait_time": ["0.11"],
                "first_token_latency": ["0.5"],
            }
        )
        request_data.to_csv(self.test_dir / "request.csv", index=False)

        kvcache_data = pd.DataFrame(
            {
                "domain": ["KVCache", "KVCache"],
                "rid": ["101", "101"],
                "timestamp": ["1749451415160", "1749451415161"],
                "name": ["Allocate", "blocks"],
                "device_kvcache_left": ["128", "256"],
            }
        )
        kvcache_data.to_csv(self.test_dir / "kvcache.csv", index=False)

    def test_source_to_model_vllm(self):
        """source_to_model (vllm) creates output_csv with a feature.csv per sub-directory."""
        source_to_model(self.test_dir, model_type="vllm")
        output_csv = self.test_dir / "output_csv"
        self.assertTrue(output_csv.exists())
        for pid_dir in output_csv.iterdir():
            if pid_dir.is_dir():
                self.assertTrue((pid_dir / "feature.csv").exists())
class TestArgParseAndMain(unittest.TestCase):
    """Tests for the ``arg_parse`` and ``main`` functions of source_to_train."""

    # Dotted path of the module under test; used to build patch targets.
    _SRC = "ms_serviceparam_optimizer.train.source_to_train"

    def setUp(self):
        # A unique temporary directory instead of a fixed CWD-relative name, so
        # concurrent or previously crashed runs cannot interfere with this one.
        self.test_dir = Path(tempfile.mkdtemp(prefix="test_source_to_train_main_"))
        self.addCleanup(shutil.rmtree, self.test_dir)

    def _start_patch(self, target):
        """Patch ``target`` with a MagicMock for the rest of the test and return the mock."""
        patcher = patch(target)
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def _patch_main_collaborators(self):
        """Replace the collaborators ``main`` is tested against with mocks stored on ``self``."""
        self.mock_logger = self._start_patch(f"{self._SRC}.logger")
        self.mock_is_root = self._start_patch(f"{self._SRC}.is_root")
        self.mock_source_to_model = self._start_patch(f"{self._SRC}.source_to_model")
        # pretrain is patched at its defining module rather than on source_to_train;
        # presumably main looks it up there at call time -- confirm if this changes.
        self.mock_pretrain = self._start_patch("ms_serviceparam_optimizer.train.pretrain.pretrain")
        self.mock_req_decodetimes = self._start_patch(f"{self._SRC}.req_decodetimes")

    def _make_args(self, model_type):
        """Return a MagicMock standing in for the parsed CLI arguments passed to ``main``."""
        args = MagicMock()
        args.input = self.test_dir
        args.output = self.test_dir / "output"
        args.type = model_type
        return args

    @patch('ms_serviceparam_optimizer.train.source_to_train.main')
    def test_arg_parse(self, mock_main):
        """arg_parse registers the "train" sub-command with three options and main as handler."""
        from ms_serviceparam_optimizer.train.source_to_train import arg_parse

        mock_subparsers = MagicMock()
        mock_parser = MagicMock()
        mock_subparsers.add_parser.return_value = mock_parser
        arg_parse(mock_subparsers)
        mock_subparsers.add_parser.assert_called_once_with(
            "train", formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="train for auto optimize"
        )
        self.assertEqual(mock_parser.add_argument.call_count, 3)
        mock_parser.set_defaults.assert_called_once_with(func=mock_main)

    def test_main_with_root_user(self):
        """As root, main logs one security warning and still calls every pipeline step."""
        self._patch_main_collaborators()
        from ms_serviceparam_optimizer.train.source_to_train import main

        self.mock_is_root.return_value = True
        main(self._make_args("mindie"))
        self.mock_logger.warning.assert_called_once()
        self.assertIn("Security Warning", self.mock_logger.warning.call_args[0][0])
        self.mock_source_to_model.assert_called_once_with(self.test_dir, "mindie")
        self.mock_pretrain.assert_called_once()
        self.mock_req_decodetimes.assert_called_once()

    def test_main_with_non_root_user(self):
        """As a non-root user, main logs no warning and calls every pipeline step."""
        self._patch_main_collaborators()
        from ms_serviceparam_optimizer.train.source_to_train import main

        self.mock_is_root.return_value = False
        main(self._make_args("mindie"))
        self.mock_logger.warning.assert_not_called()
        self.mock_source_to_model.assert_called_once_with(self.test_dir, "mindie")
        self.mock_pretrain.assert_called_once()
        self.mock_req_decodetimes.assert_called_once()

    def test_main_with_vllm_type(self):
        """main forwards the "vllm" model type to source_to_model."""
        self._patch_main_collaborators()
        from ms_serviceparam_optimizer.train.source_to_train import main

        self.mock_is_root.return_value = False
        main(self._make_args("vllm"))
        self.mock_source_to_model.assert_called_once_with(self.test_dir, "vllm")
        self.mock_pretrain.assert_called_once()
        self.mock_req_decodetimes.assert_called_once()

    def test_main_with_io_error(self):
        """An IOError from source_to_model is logged, re-raised, and stops the pipeline."""
        self._patch_main_collaborators()
        from ms_serviceparam_optimizer.train.source_to_train import main

        self.mock_is_root.return_value = False
        self.mock_source_to_model.side_effect = IOError("File not found")
        with self.assertRaises(IOError):
            main(self._make_args("mindie"))
        self.mock_logger.error.assert_called_once()
        self.assertIn("无法读取输入文件", self.mock_logger.error.call_args[0][0])
        self.mock_pretrain.assert_not_called()
        self.mock_req_decodetimes.assert_not_called()

    def test_main_with_pretrain_error(self):
        """A pretrain failure is logged (not raised) and req_decodetimes is skipped."""
        self._patch_main_collaborators()
        from ms_serviceparam_optimizer.train.source_to_train import main

        self.mock_is_root.return_value = False
        self.mock_pretrain.side_effect = Exception("Pretrain failed")
        main(self._make_args("mindie"))
        self.mock_logger.error.assert_called_once()
        self.assertIn("pretrain failed", self.mock_logger.error.call_args[0][0])
        self.mock_source_to_model.assert_called_once_with(self.test_dir, "mindie")
        self.mock_req_decodetimes.assert_not_called()
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()