import unittest
from unittest.mock import patch, MagicMock
from pathlib import Path
import os
import shutil
import pandas as pd
from ms_serviceparam_optimizer.train.pretrain import PretrainModel, pretrain
from ms_serviceparam_optimizer.train.state_param import StateParam
from ms_serviceparam_optimizer.data_feature.dataset import MyDataSet
from ms_serviceparam_optimizer.model.xgb_state_model import StateXgbModel
class TestPretrainModel(unittest.TestCase):
def setUp(self):
self.test_dir = "test_data"
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir, 0o750)
self.input_path = Path(self.test_dir) / "input"
self.output_path = Path(self.test_dir) / "output"
self.feature_path = Path(self.input_path) / "test"
self.input_path.mkdir()
self.feature_path.mkdir()
self.output_path.mkdir()
pretrain_data_path = os.path.join(os.path.dirname(__file__), 'pretrain_data.json')
if os.path.exists(pretrain_data_path):
self.real_data = pd.read_csv(pretrain_data_path)
self.real_data_file = Path(self.feature_path) / "feature.csv"
self.real_data.to_csv(self.real_data_file, index=False)
if self.real_data is not None:
self.real_dataset = MyDataSet()
sp = StateParam(
base_path=self.output_path,
predict_field="model_execute_time",
save_model=True,
shuffle=True,
plot_pred_and_real=True,
plot_data_feature=True,
start_num_lines=4000,
title="MixModel without warmup with batch max seq 2 op info",
)
self.state_param = sp
self.real_model = StateXgbModel(
train_param=sp.xgb_model_train_param,
update_param=sp.xgb_model_update_param,
save_model_path=sp.xgb_model_save_model_path,
load_model_path=sp.xgb_model_save_model_path,
show_test_data_prediction=sp.xgb_model_show_test_data_prediction,
show_feature_importance=sp.xgb_model_show_feature_importance,
)
self.real_trainer = PretrainModel(
state_param=sp, dataset=self.real_dataset, model=self.real_model, plt_data=False
)
self.dataset = MagicMock()
self.model = MagicMock()
self.pm = PretrainModel(state_param=sp, dataset=self.dataset, model=self.model)
self.trainer = PretrainModel(state_param=sp, dataset=self.dataset, model=self.model, plt_data=False)
self.sample_data = pd.DataFrame(
{
'batch_stage': ['prefill', 'decode', 'prefill'],
'batch_size': [1, 2, 3],
'model_execute_time': [100, 200, 300],
}
)
self.features = pd.DataFrame({'batch_stage': [1, 0], 'batch_size': [1, 2], 'feature1': [0.5, 0.7]})
self.labels = pd.Series([100, 200])
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_train_with_sample_data(self):
"""测试使用样本数据训练"""
self.dataset.construct_data.return_value = None
self.model.train.return_value = 0.25
self.trainer.train(lines_data=self.sample_data)
self.dataset.construct_data.assert_called_once_with(self.sample_data, plt_data=False, middle_save_path=None)
self.model.train.assert_called_once_with(self.dataset, middle_save_path=None)
self.assertIn(0.25, self.trainer.rmse)
def test_get_up_ud(self):
class Node:
def __init__(self, stage, batch_size, time):
self.stage = stage
self.batch_size = batch_size
self.model_execute_time = time
nodes = [Node('prefill', 1, 100), Node('decode', 2, 200)]
prefill, decode = PretrainModel.get_up_ud(tuple(nodes), "model_execute_time")
self.assertEqual(prefill[list(prefill)[0]], [10000.0])
self.assertEqual(decode[list(decode)[0]], [5000.0])
def test_model_backup(self):
"""测试模型备份功能"""
model_file = self.state_param.xgb_model_save_model_path
model_file.parent.mkdir(parents=True, exist_ok=True)
with open(model_file, "w", encoding="utf-8") as f:
f.write("dummy model data")
self.trainer.bak_model(increment_stage="test_stage")
bak_file = self.state_param.bak_dir / "test_stage" / model_file.name
self.assertTrue(bak_file.exists())
@patch("matplotlib.pyplot.savefig")
@patch("matplotlib.pyplot.close")
def test_metric_visualization(self, mock_close, mock_savefig):
"""测试指标可视化"""
self.trainer.rmse = [0.25, 0.18, 0.15]
self.trainer.r2 = [0.92, 0.94, 0.96]
self.trainer.mape = [0.05, 0.04, 0.03]
save_path = Path(self.test_dir) / "metrics"
save_path.mkdir()
self.trainer.plot_metric(save_path=save_path)
mock_savefig.assert_called_once_with(save_path / "metric.png")
self.assertTrue(mock_savefig.called)
def test_pretrain_workflow(self):
"""测试完整的训练流程"""
pretrain(str(self.input_path), str(self.output_path))
self.assertTrue((self.output_path / "model").exists())
self.assertTrue((self.output_path / "step").exists())
self.assertTrue((self.output_path / "bak").exists())
self.assertTrue((self.output_path / "model" / "xgb_model.ubj").exists())
self.assertTrue((self.output_path / "cache" / "train_data.csv").exists())
def test_get_decode_and_prefill_time(self):
class Node:
def __init__(self, stage, batch_size, time):
self.stage = stage
self.batch_size = batch_size
self.model_execute_time = time
nodes = [Node('prefill', 1, 100), Node('decode', 2, 200), Node('prefill', 3, 300)]
prefill, decode = PretrainModel.get_decode_and_prefill_time(tuple(nodes), "model_execute_time")
self.assertEqual(prefill[list(prefill)[0]], [100])
self.assertEqual(prefill[list(prefill)[1]], [300])
self.assertEqual(decode[list(decode)[0]], [200])
def test_train(self):
with patch.object(self.dataset, 'construct_data') as mock_construct:
with patch.object(self.model, 'train') as mock_train:
mock_train.return_value = 0.5
self.pm.train(self.sample_data)
mock_construct.assert_called_once_with(self.sample_data, plt_data=False, middle_save_path=None)
mock_train.assert_called_once_with(self.dataset, middle_save_path=None)
self.assertEqual(self.pm.rmse, [0.5])
def test_partial_train(self):
with patch.object(self.dataset.custom_encoder, 'fit') as mock_fit:
with patch.object(self.dataset, 'construct_data') as mock_construct:
with patch.object(self.model, 'train') as mock_train:
mock_train.return_value = 0.3
self.pm.partial_train(self.sample_data)
mock_fit.assert_called_once_with(load=True)
mock_construct.assert_called_once_with(self.sample_data, plt_data=False, middle_save_path=None)
mock_train.assert_called_once_with(self.dataset, train_type="partial_fit", middle_save_path=None)
self.assertEqual(self.pm.rmse, [0.3])
def test_initialization(self):
"""测试训练器初始化"""
self.assertIs(self.trainer.state_param, self.state_param)
self.assertIs(self.trainer.dataset, self.dataset)
self.assertIs(self.trainer.model, self.model)
self.assertFalse(self.trainer.plt_data)
if __name__ == '__main__':
unittest.main()