import json
import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ms_serviceparam_optimizer.analysis import (
    AnalysisState,
    PlotConfig,
    State,
)
class TestAnalysisState(unittest.TestCase):
    """Unit tests for the statistics and plotting helpers on ``AnalysisState``.

    The plotting tests patch the ``matplotlib.pyplot`` functions they expect
    to be called and assert on the recorded calls, so no figure is rendered.
    """

    def setUp(self):
        """Create the latency fixtures and a private scratch directory."""
        self.test_data = {
            State(batch_prefill=1): [10.0, 10.5, 11.0],
            State(batch_prefill=2): [20.0, 20.5, 21.0],
            State(batch_decode=3): [30.0, 30.5, 31.0],
            State(batch_decode=4): [40.0, 40.5, 41.0],
        }
        self.single_data = {
            State(batch_prefill=1): [10.0],
            State(batch_prefill=2): [20.0],
        }
        # A unique directory per test avoids collisions between concurrent or
        # interrupted runs and does not assume that "/tmp" exists.
        self.save_path = Path(tempfile.mkdtemp(prefix="test_save_path_"))

    def tearDown(self):
        """Delete the scratch directory together with anything left in it."""
        # rmtree also copes with nested directories, which a plain rmdir()
        # after unlinking the top-level files would not.
        shutil.rmtree(self.save_path, ignore_errors=True)

    def test_computer_mean_sigma(self):
        """Samples of states sharing a ``batch_prefill`` value are pooled."""
        test_data = {
            State(batch_prefill=1): [10.0, 10.5, 11.0],
            State(batch_prefill=1, batch_decode=10): [10.1, 10.6, 11.1],
            State(batch_prefill=2): [20.0, 20.5, 21.0],
            State(batch_prefill=2, batch_decode=20): [20.1, 20.6, 21.1],
        }
        group1_data = [10.0, 10.5, 11.0, 10.1, 10.6, 11.1]
        group2_data = [20.0, 20.5, 21.0, 20.1, 20.6, 21.1]
        expected_mean1 = np.mean(group1_data)
        expected_mean2 = np.mean(group2_data)

        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_data, "batch_prefill"
        )

        for series in (x, mean, pos_sigma, neg_sigma):
            self.assertIsInstance(series, list)
        # Four states collapse into two distinct batch_prefill groups.
        self.assertEqual(len(x), 2)
        self.assertAlmostEqual(mean[0], expected_mean1, places=2)
        self.assertAlmostEqual(mean[1], expected_mean2, places=2)

    @patch('matplotlib.pyplot.plot')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.grid')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    def test_plot_input_velocity(self, mock_close, mock_savefig, mock_ylabel,
                                 mock_xlabel, mock_title, mock_grid, mock_legend,
                                 mock_plot):
        """``plot_input_velocity`` saves when given a path and shows otherwise."""
        config = PlotConfig(
            data=self.test_data,
            x_field="batch_prefill",
            title="Test Plot",
            x_label="Batch Size",
            y_label="Latency (ms)",
            save_path=str(self.save_path),
        )

        AnalysisState.plot_input_velocity(config)

        # Three curves are expected -- presumably the mean and the two sigma
        # bounds; confirm against the implementation.
        self.assertEqual(mock_plot.call_count, 3)
        mock_title.assert_called_once_with("Test Plot")
        mock_xlabel.assert_called_once_with("Batch Size")
        mock_ylabel.assert_called_once_with("Latency (ms)")
        mock_legend.assert_called_once()
        mock_grid.assert_called_once()
        mock_savefig.assert_called_once()
        mock_close.assert_called_once()

        # Without a save path the figure must be displayed instead.
        config.save_path = None
        mock_plot.reset_mock()
        with patch('matplotlib.pyplot.show') as mock_show:
            AnalysisState.plot_input_velocity(config)
        mock_show.assert_called_once()

    @patch('matplotlib.pyplot.figure')
    @patch('matplotlib.pyplot.scatter')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    @patch('matplotlib.pyplot.show')
    def test_plot_pred_and_real(self, mock_show, mock_close, mock_savefig,
                                mock_legend, mock_ylabel, mock_xlabel,
                                mock_title, mock_scatter, mock_figure):
        """``plot_pred_and_real`` draws two scatter series, then saves or shows."""
        pred = [1.1, 2.1, 3.1]
        real = [1.0, 2.0, 3.0]

        AnalysisState.plot_pred_and_real(pred, real, self.save_path)

        self.assertEqual(mock_scatter.call_count, 2)
        mock_title.assert_called_once_with("predict value and real value")
        mock_xlabel.assert_called_once_with("index")
        mock_ylabel.assert_called_once_with("value")
        mock_legend.assert_called_once()
        mock_savefig.assert_called_once_with(
            self.save_path / "predict value and real value.png"
        )
        mock_close.assert_called_once()

        # Without a save path the figure must be displayed instead.
        mock_scatter.reset_mock()
        mock_savefig.reset_mock()
        AnalysisState.plot_pred_and_real(pred, real, None)
        mock_show.assert_called_once()

    def test_std_calculations(self):
        """Check the sigma bounds for a three-sample and a one-sample group."""
        test_data = {
            State(batch_prefill=1): [1.0, 2.0, 3.0],
        }
        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_data, "batch_prefill"
        )
        self.assertAlmostEqual(mean[0], 2.0, places=1)
        # 3.0 matches mean + sample standard deviation (ddof=1) of [1, 2, 3];
        # presumably that is how the bound is computed -- confirm.
        self.assertAlmostEqual(pos_sigma[0], 3.0, places=1)

        # A single sample has no spread: both bounds must equal the mean.
        test_single_point = {
            State(batch_prefill=1): [1.0],
        }
        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_single_point, "batch_prefill"
        )
        self.assertEqual(mean[0], 1.0)
        self.assertEqual(pos_sigma[0], 1.0)
        self.assertEqual(neg_sigma[0], 1.0)

    @patch('matplotlib.pyplot.plot')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.grid')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    @patch('matplotlib.pyplot.show')
    def test_plot_input_velocity_with_df(self, mock_show, mock_close, mock_savefig,
                                         mock_ylabel, mock_xlabel, mock_grid,
                                         mock_legend, mock_title, mock_plot):
        """Check the plotting calls made per batch stage, saving or showing.

        NOTE(review): ``AnalysisState.plot_input_velocity_with_df`` is replaced
        by the local stub below, so these assertions exercise the stub rather
        than the production implementation. Driving the real method needs
        DataFrames in the schema it expects, which is not visible from here.
        """
        with patch.object(AnalysisState, 'plot_input_velocity_with_df') as mock_method:

            def mock_implementation(predict_df, origin_df, save_path):
                """Emit one six-curve figure per batch stage."""
                for batch_stage in ('prefill', 'decode'):
                    plt.plot([1, 2], [15.0, 35.0], label="predict mean")
                    plt.plot([1, 2], [20.0, 40.0], label="predict positive std")
                    plt.plot([1, 2], [10.0, 30.0], label="predict negative std")
                    plt.plot([1, 2], [17.0, 37.0], label="origin mean")
                    plt.plot([1, 2], [22.0, 42.0], label="origin positive std")
                    plt.plot([1, 2], [12.0, 32.0], label="origin negative std")
                    plt.title(f"{batch_stage} latency")
                    plt.legend()
                    plt.grid()
                    plt.xlabel("batch size")
                    plt.ylabel("res")
                    if save_path:
                        plt.savefig(Path(save_path) / f"{batch_stage}_batch_size_res.png")
                        plt.close()
                    else:
                        plt.show()

            mock_method.side_effect = mock_implementation
            mock_predict_df = MagicMock()
            mock_origin_df = MagicMock()

            AnalysisState.plot_input_velocity_with_df(
                mock_predict_df, mock_origin_df, self.save_path
            )

            # Two stages, six curves each.
            self.assertEqual(mock_plot.call_count, 12)
            for per_stage_mock in (mock_title, mock_xlabel, mock_ylabel, mock_legend,
                                   mock_grid, mock_savefig, mock_close):
                self.assertEqual(per_stage_mock.call_count, 2)

            expected_calls = [
                (self.save_path / "prefill_batch_size_res.png",),
                (self.save_path / "decode_batch_size_res.png",),
            ]
            # Index 0 of a recorded call is its positional-argument tuple.
            actual_calls = [recorded[0] for recorded in mock_savefig.call_args_list]
            self.assertEqual(actual_calls, expected_calls)

            # Without a save path each stage must be displayed instead.
            for used_mock in (mock_plot, mock_savefig, mock_close, mock_show):
                used_mock.reset_mock()
            AnalysisState.plot_input_velocity_with_df(mock_predict_df, mock_origin_df, None)
            self.assertEqual(mock_show.call_count, 2)

    @patch('matplotlib.pyplot.figure')
    @patch('matplotlib.pyplot.plot')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.grid')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    @patch('matplotlib.pyplot.show')
    @patch('ms_serviceparam_optimizer.analysis.open_s')
    def test_plot_input_velocity_with_predict(self, mock_open_s, mock_show, mock_close,
                                              mock_savefig, mock_ylabel, mock_xlabel,
                                              mock_grid, mock_legend, mock_title,
                                              mock_plot, mock_figure):
        """Cover the save, show and missing-axis-label paths of the predict plot."""
        config = PlotConfig(
            data=self.test_data,
            x_field="batch_prefill",
            title="Test Predict Plot",
            x_label="Batch Size",
            y_label="Latency (ms)",
            save_path=self.save_path,
        )
        predict_data = {
            State(batch_prefill=1): [11.0, 11.5, 12.0],
            State(batch_prefill=2): [21.0, 21.5, 22.0],
        }
        # Capture what is written through the ``open_s`` context manager
        # instead of touching the file system.
        mock_file = MagicMock()
        mock_open_s.return_value.__enter__.return_value = mock_file

        # --- Save path given: figure is written to disk with a data dump ---
        AnalysisState.plot_input_velocity_with_predict(config, predict_data)

        self.assertEqual(mock_figure.call_count, 1)
        self.assertEqual(mock_plot.call_count, 6)
        mock_title.assert_called_once_with("Test Predict Plot")
        mock_xlabel.assert_called_once_with("Batch Size")
        mock_ylabel.assert_called_once_with("Latency (ms)")
        mock_legend.assert_called_once()
        mock_grid.assert_called_once()
        expected_save_path = self.save_path.joinpath(
            "Batch Size_Latency (ms)_Test Predict Plot.png"
        )
        mock_savefig.assert_called_once_with(expected_save_path)
        mock_close.assert_called_once()

        write_calls = mock_file.write.call_args_list
        self.assertEqual(len(write_calls), 11)
        # The dump opens with a "mean" header followed by a JSON-encoded list.
        self.assertEqual(write_calls[0][0][0], 'mean\n')
        mean_data = json.loads(write_calls[1][0][0])
        self.assertIsInstance(mean_data, list)

        # --- No save path: figure is shown and nothing is written ---
        config.save_path = None
        for used_mock in (mock_plot, mock_savefig, mock_close, mock_show, mock_open_s):
            used_mock.reset_mock()

        AnalysisState.plot_input_velocity_with_predict(config, predict_data)

        mock_show.assert_called_once()
        mock_savefig.assert_not_called()
        mock_close.assert_not_called()
        mock_open_s.assert_not_called()

        # --- Axis labels unset: the label setters must be skipped ---
        config.x_label = None
        config.y_label = None
        config.save_path = self.save_path
        mock_xlabel.reset_mock()
        mock_ylabel.reset_mock()

        AnalysisState.plot_input_velocity_with_predict(config, predict_data)

        mock_xlabel.assert_not_called()
        mock_ylabel.assert_not_called()