import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch, MagicMock

import numpy as np
import pandas as pd

from msserviceprofiler.modelevalstate.analysis import (
    AnalysisState,
    PlotConfig,
    State
)
class TestAnalysisState(unittest.TestCase):
    """Unit tests for the statistics and plotting helpers of ``AnalysisState``.

    The plotting tests patch the ``matplotlib.pyplot`` functions they care
    about, so they assert on the calls that were made rather than on any
    rendered output.
    """

    def setUp(self):
        """Create the shared sample fixtures and a private scratch directory."""
        # Sample lists keyed by State; two keys vary batch_prefill and two
        # vary batch_decode.
        self.test_data = {
            State(batch_prefill=1): [10.0, 10.5, 11.0],
            State(batch_prefill=2): [20.0, 20.5, 21.0],
            State(batch_decode=3): [30.0, 30.5, 31.0],
            State(batch_decode=4): [40.0, 40.5, 41.0],
        }
        self.single_data = {
            State(batch_prefill=1): [10.0],
            State(batch_prefill=2): [20.0],
        }
        # Use a unique temporary directory per test instead of a fixed
        # location under /tmp: concurrent runs cannot collide, and the
        # cleanup removes the whole tree even if nested entries were created.
        # addCleanup runs whether the test passes, fails or errors.
        scratch_dir = tempfile.TemporaryDirectory()
        self.addCleanup(scratch_dir.cleanup)
        self.save_path = Path(scratch_dir.name)

    @patch('matplotlib.pyplot.plot')
    @patch('matplotlib.pyplot.show')
    @patch('matplotlib.pyplot.close')
    def test_computer_mean_sigma(self, mock_close, mock_show, mock_plot):
        """Check the per-group means returned by ``computer_mean_sigma``.

        The expectation encoded here is that states sharing the same
        ``batch_prefill`` value are pooled: each expected mean is computed
        over the samples of both states with that value.
        """
        test_data = {
            State(batch_prefill=1): [10.0, 10.5, 11.0],
            State(batch_prefill=1, batch_decode=10): [10.1, 10.6, 11.1],
            State(batch_prefill=2): [20.0, 20.5, 21.0],
            State(batch_prefill=2, batch_decode=20): [20.1, 20.6, 21.1],
        }
        # All samples whose state has batch_prefill == 1 / == 2 respectively.
        group1_data = [10.0, 10.5, 11.0, 10.1, 10.6, 11.1]
        group2_data = [20.0, 20.5, 21.0, 20.1, 20.6, 21.1]
        expected_mean1 = np.mean(group1_data)
        expected_mean2 = np.mean(group2_data)

        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_data, "batch_prefill"
        )

        self.assertIsInstance(x, list)
        self.assertIsInstance(mean, list)
        self.assertIsInstance(pos_sigma, list)
        self.assertIsInstance(neg_sigma, list)
        # Two distinct batch_prefill values -> two x positions.
        self.assertEqual(len(x), 2)
        self.assertAlmostEqual(mean[0], expected_mean1, places=2)
        self.assertAlmostEqual(mean[1], expected_mean2, places=2)

    @patch('matplotlib.pyplot.plot')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.grid')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    def test_plot_input_velocity(self, mock_close, mock_savefig, mock_ylabel,
                                 mock_xlabel, mock_title, mock_grid, mock_legend,
                                 mock_plot):
        """Check the pyplot calls made by ``plot_input_velocity``.

        First with a ``save_path`` (the figure is expected to be saved and
        closed), then with ``save_path=None`` (``pyplot.show`` is expected to
        be called instead).
        """
        config = PlotConfig(
            data=self.test_data,
            x_field="batch_prefill",
            title="Test Plot",
            x_label="Batch Size",
            y_label="Latency (ms)",
            save_path=str(self.save_path)
        )

        AnalysisState.plot_input_velocity(config)

        # Three plot() calls are expected -- presumably the mean curve plus
        # the +sigma and -sigma curves; confirm against the implementation.
        self.assertEqual(mock_plot.call_count, 3)
        mock_title.assert_called_once_with("Test Plot")
        mock_xlabel.assert_called_once_with("Batch Size")
        mock_ylabel.assert_called_once_with("Latency (ms)")
        mock_legend.assert_called_once()
        mock_grid.assert_called_once()
        mock_savefig.assert_called_once()
        mock_close.assert_called_once()

        # Without a save path the plot should be shown interactively.
        config.save_path = None
        mock_plot.reset_mock()
        with patch('matplotlib.pyplot.show') as mock_show:
            AnalysisState.plot_input_velocity(config)
            mock_show.assert_called_once()

    @patch('matplotlib.pyplot.figure')
    @patch('matplotlib.pyplot.scatter')
    @patch('matplotlib.pyplot.title')
    @patch('matplotlib.pyplot.xlabel')
    @patch('matplotlib.pyplot.ylabel')
    @patch('matplotlib.pyplot.legend')
    @patch('matplotlib.pyplot.savefig')
    @patch('matplotlib.pyplot.close')
    @patch('matplotlib.pyplot.show')
    def test_plot_pred_and_real(self, mock_show, mock_close, mock_savefig,
                                mock_legend, mock_ylabel, mock_xlabel,
                                mock_title, mock_scatter, mock_figure):
        """Check the pyplot calls made by ``plot_pred_and_real``.

        With a save directory the figure is expected to be written to
        ``<save_path>/predict value and real value.png``; with ``None`` the
        figure is expected to be shown.
        """
        pred = [1.1, 2.1, 3.1]
        real = [1.0, 2.0, 3.0]

        AnalysisState.plot_pred_and_real(pred, real, self.save_path)

        # One scatter() call per series (predicted and real).
        self.assertEqual(mock_scatter.call_count, 2)
        mock_title.assert_called_once_with("predict value and real value")
        mock_xlabel.assert_called_once_with("index")
        mock_ylabel.assert_called_once_with("value")
        mock_legend.assert_called_once()
        mock_savefig.assert_called_once_with(
            self.save_path / "predict value and real value.png"
        )
        mock_close.assert_called_once()

        # Second run without a save path: only show() is asserted here.
        mock_scatter.reset_mock()
        mock_savefig.reset_mock()
        AnalysisState.plot_pred_and_real(pred, real, None)
        mock_show.assert_called_once()

    def test_std_calculations(self):
        """Check the sigma bounds returned by ``computer_mean_sigma``.

        For ``[1.0, 2.0, 3.0]`` the test expects a mean of 2.0 and an upper
        bound of 3.0 (consistent with mean + sample standard deviation --
        NOTE(review): confirm against the implementation). For a single
        sample the mean and both bounds are expected to equal that sample.
        """
        test_data = {
            State(batch_prefill=1): [1.0, 2.0, 3.0]
        }
        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_data, "batch_prefill"
        )
        self.assertAlmostEqual(mean[0], 2.0, places=1)
        self.assertAlmostEqual(pos_sigma[0], 3.0, places=1)

        # Degenerate case: one sample, so no spread around the mean.
        test_single_point = {
            State(batch_prefill=1): [1.0]
        }
        x, mean, pos_sigma, neg_sigma = AnalysisState.computer_mean_sigma(
            test_single_point, "batch_prefill"
        )
        self.assertEqual(mean[0], 1.0)
        self.assertEqual(pos_sigma[0], 1.0)
        self.assertEqual(neg_sigma[0], 1.0)